Caffe2 - C++ API
A deep learning, cross platform ML framework
given_tensor_fill_op.cc
1 
17 #include "caffe2/operators/given_tensor_fill_op.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(GivenTensorFill, GivenTensorFillOp<float, CPUContext>);
22 REGISTER_CPU_OPERATOR(
23  GivenTensorDoubleFill,
24  GivenTensorFillOp<double, CPUContext>);
25 REGISTER_CPU_OPERATOR(GivenTensorBoolFill, GivenTensorFillOp<bool, CPUContext>);
26 REGISTER_CPU_OPERATOR(GivenTensorIntFill, GivenTensorFillOp<int, CPUContext>);
27 REGISTER_CPU_OPERATOR(
28  GivenTensorInt64Fill,
29  GivenTensorFillOp<int64_t, CPUContext>);
30 REGISTER_CPU_OPERATOR(
31  GivenTensorStringFill,
32  GivenTensorFillOp<std::string, CPUContext>);
33 
34 NO_GRADIENT(GivenTensorFill);
35 NO_GRADIENT(GivenTensorDoubleFill);
36 NO_GRADIENT(GivenTensorBoolFill);
37 NO_GRADIENT(GivenTensorIntFill);
38 NO_GRADIENT(GivenTensorInt64Fill);
39 NO_GRADIENT(GivenTensorStringFill);
40 
41 OPERATOR_SCHEMA(GivenTensorFill)
42  .NumInputs(0, 1)
43  .NumOutputs(1)
44  .AllowInplace({{0, 0}})
45  .TensorInferenceFunction(FillerTensorInference<>);
46 OPERATOR_SCHEMA(GivenTensorDoubleFill)
47  .NumInputs(0, 1)
48  .NumOutputs(1)
49  .AllowInplace({{0, 0}})
50  .TensorInferenceFunction(
51  FillerTensorInference<TensorProto_DataType_DOUBLE>);
52 OPERATOR_SCHEMA(GivenTensorBoolFill)
53  .NumInputs(0, 1)
54  .NumOutputs(1)
55  .AllowInplace({{0, 0}})
56  .TensorInferenceFunction(FillerTensorInference<TensorProto_DataType_BOOL>);
57 OPERATOR_SCHEMA(GivenTensorIntFill)
58  .NumInputs(0, 1)
59  .NumOutputs(1)
60  .AllowInplace({{0, 0}})
61  .TensorInferenceFunction(FillerTensorInference<TensorProto_DataType_INT32>);
62 OPERATOR_SCHEMA(GivenTensorInt64Fill)
63  .NumInputs(0, 1)
64  .NumOutputs(1)
65  .AllowInplace({{0, 0}})
66  .TensorInferenceFunction(FillerTensorInference<TensorProto_DataType_INT64>);
67 OPERATOR_SCHEMA(GivenTensorStringFill)
68  .NumInputs(0, 1)
69  .NumOutputs(1)
70  .AllowInplace({{0, 0}})
71  .TensorInferenceFunction(
72  FillerTensorInference<TensorProto_DataType_STRING>);
73 
74 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.