// Pulls in the TTContraction operator header and binds the operator name
// TTContraction to TTContractionOp<float, CPUContext> through
// REGISTER_CPU_OPERATOR.
// NOTE(review): the leading integers ("17", "21") look like line numbers
// carried over from a text extraction rather than program tokens -- confirm
// against the original file.
17 #include "caffe2/experiments/operators/tt_contraction_op.h" 21 REGISTER_CPU_OPERATOR(TTContraction, TTContractionOp<float, CPUContext>);
// Schema for the TTContraction operator.
//
// Per the description text below, the operator is the tensor contraction
// C = A * B. The schema documents:
//   - Arg "K":   described as "i_{k-1} * r_k"
//   - Arg "M":   described as "r_{k-1} * o_{k-1}"
//   - Input 0  "A": 2D matrix of size (K x M)
//   - Input 1  "B": tensor
//   - Output 0 "C": contracted tensor
//
// NOTE(review): the description sentence appears bare in the macro chain and
// no input/output count declaration is visible here; presumably the
// surrounding doc-string wrapper and count calls were lost in extraction --
// confirm against the original file.
23 OPERATOR_SCHEMA(TTContraction)
27 Tensor contraction C = A * B 29 .Arg("K",
"i_{k-1} * r_k")
30 .Arg(
"M",
"r_{k-1} * o_{k-1}")
32 .Input(0,
"A",
"2D matrix of size (K x M)")
33 .Input(1,
"B",
"tensor")
34 .Output(0,
"C",
"contracted tensor");
// Binds the operator name TTContractionGradient to
// TTContractionGradientOp<float, CPUContext> through REGISTER_CPU_OPERATOR.
36 REGISTER_CPU_OPERATOR(
37 TTContractionGradient,
38 TTContractionGradientOp<float, CPUContext>);
// Schema for TTContractionGradient: declares exactly 3 inputs and 2 outputs.
// No argument/input/output descriptions are attached to this schema.
40 OPERATOR_SCHEMA(TTContractionGradient).NumInputs(3).NumOutputs(2);
// Gradient maker for TTContraction.
//
// GetGradientDefs() returns a single operator definition of type
// "TTContractionGradient" built with SingleGradientDef. Its inputs are
// {GO(0), I(0), I(1)} and its outputs are {GI(0), GI(1)}.
// Presumably GO/I/GI follow the GradientMakerBase naming helpers (gradient of
// forward output, forward input, gradient of forward input) -- verify against
// the GradientMakerBase declaration.
//
// NOTE(review): the end of the SingleGradientDef(...) call and the closing
// braces of the function and class are not visible in this excerpt.
42 class GetTTContractionGradient :
public GradientMakerBase {
// Inherit the base class constructors.
43 using GradientMakerBase::GradientMakerBase;
44 vector<OperatorDef> GetGradientDefs()
override {
45 return SingleGradientDef(
46 "TTContractionGradient",
// Inputs of the gradient op (three names).
48 vector<string>{GO(0), I(0), I(1)},
// Outputs of the gradient op (two names).
49 vector<string>{GI(0), GI(1)},
// Associates the TTContraction operator with GetTTContractionGradient as its
// gradient maker through REGISTER_GRADIENT.
54 REGISTER_GRADIENT(TTContraction, GetTTContractionGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...