Caffe2 - C++ API
A deep learning, cross platform ML framework
inference_lstm_op.cc
1 #include "caffe2/operators/inference_lstm_op.h"
2 
3 namespace caffe2 {
4 namespace {
5 
6 bool InferenceLSTMOp::RunOnDevice() {
7  auto& _input = Input(0);
8  auto& hidden_0 = Input(1);
9  auto& hidden_1 = Input(2);
10  std::vector<Tensor> params;
11  for (int i = 3; i < InputSize(); i++) {
12  params.push_back(Input(i).UnsafeSharedInstance());
13  }
14  auto input = batch_first_ ? transpose(_input, 0, 1, &context_)
15  : _input.UnsafeSharedInstance();
16 
17  auto cell_params = gather_params(params, has_biases_, &context_);
18  auto results = _lstm_impl(
19  input,
20  cell_params,
21  hidden_0,
22  hidden_1,
23  num_layers_,
24  bidirectional_,
25  &context_);
26 
27  std::vector<Tensor> allOutputs(OutputSize());
28  allOutputs.at(0) = copy_ctor(std::get<0>(results));
29  if (batch_first_) {
30  allOutputs.at(0) = transpose(allOutputs.at(0), 0, 1, &context_);
31  }
32  allOutputs.at(1) = copy_ctor(std::get<1>(results));
33  allOutputs.at(2) = copy_ctor(std::get<2>(results));
34  for (int i = 0; i < OutputSize(); i++) {
35  auto output = XOutput(i, allOutputs.at(i).sizes(), dtype<float>());
36  context_.CopyItemsSameDevice(
37  allOutputs.at(i).dtype(),
38  allOutputs.at(i).numel(),
39  allOutputs.at(i).template data<float>(),
40  output.template mutable_data<float>());
41  }
42  return true;
43 }
44 
45 REGISTER_CPU_OPERATOR(InferenceLSTM, InferenceLSTMOp);
46 OPERATOR_SCHEMA(InferenceLSTM)
47  .NumInputs(1, INT_MAX)
48  .NumOutputs(3)
49  .Output(0, "output", "the output of the last layer of lstm")
50  .Output(1, "hidden", "hidden state at t = seq_len")
51  .Output(2, "cell", "cell state at t = seq_len")
52  .Arg("num_layers", "(*long*): number of layers in the lstm stack")
53  .Arg("has_biases", "(*bool*): whether the cells have biases or not")
54  .Arg("batch_first", "(*bool*): whether the batch is at dim 0")
55  .Arg("bidirectional", "(*bool*): if bidirectional");
56 NO_GRADIENT(InferenceLSTM);
57 } // namespace
58 } // namespace caffe2
59 
// Registers caffe2::InferenceLSTMOp under the name InferenceLSTM through the
// C10_REGISTER_CAFFE2_OPERATOR_CPU macro, together with its argument and
// return lists.
// NOTE(review): presumably this exposes the CPU operator to the c10
// dispatcher -- confirm against the macro definition.
60 C10_REGISTER_CAFFE2_OPERATOR_CPU(
61  InferenceLSTM,
// Arguments: one tensor list followed by the four scalar options, which share
// their names with the Arg() entries of the InferenceLSTM operator schema.
62  (std::vector<c10::Argument>{
63  c10::Argument("input_list", ListType::ofTensors()),
64  c10::Argument("num_layers", IntType::get()),
65  c10::Argument("has_biases", BoolType::get()),
66  c10::Argument("batch_first", BoolType::get()),
67  c10::Argument("bidirectional", BoolType::get())}),
// Returns: three values, named like the three schema outputs.
68  (std::vector<c10::Argument>{c10::Argument("output"),
69  c10::Argument("hidden"),
70  c10::Argument("cell")}),
71  caffe2::InferenceLSTMOp);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13