Caffe2 - C++ API
A deep learning, cross-platform ML framework
tt_contraction_op.cc
1 
17 #include "caffe2/experiments/operators/tt_contraction_op.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(TTContraction, TTContractionOp<float, CPUContext>);
22 
23 OPERATOR_SCHEMA(TTContraction)
24  .NumInputs(2)
25  .NumOutputs(1)
26  .SetDoc(R"DOC(
27 Tensor contraction C = A * B
28 )DOC")
29  .Arg("K", "i_{k-1} * r_k")
30  .Arg("M", "r_{k-1} * o_{k-1}")
31  .Arg("N", "o_k")
32  .Input(0, "A", "2D matrix of size (K x M)")
33  .Input(1, "B", "tensor")
34  .Output(0, "C", "contracted tensor");
35 
36 REGISTER_CPU_OPERATOR(
37  TTContractionGradient,
38  TTContractionGradientOp<float, CPUContext>);
39 
40 OPERATOR_SCHEMA(TTContractionGradient).NumInputs(3).NumOutputs(2);
41 
42 class GetTTContractionGradient : public GradientMakerBase {
43  using GradientMakerBase::GradientMakerBase;
44  vector<OperatorDef> GetGradientDefs() override {
45  return SingleGradientDef(
46  "TTContractionGradient",
47  "",
48  vector<string>{GO(0), I(0), I(1)},
49  vector<string>{GI(0), GI(1)},
50  Def().arg());
51  }
52 };
53 
54 REGISTER_GRADIENT(TTContraction, GetTTContractionGradient);
55 
56 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.