#include <c10/core/Tensor.h>
#include <algorithm>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "lstm_utils.h"

C10_DECLARE_CAFFE2_OPERATOR(LSTMOp);

namespace caffe2 {
namespace {
using t_tuple = std::tuple<Tensor, Tensor>;
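// Note: the helper functions used below (linear(), add(), mul(), chunk(),
// sigmoid(), tanh(), stack(), unbind(), cat(), copy_ctor(), pair_vec(),
// unpair_vec()) are assumed to come from "lstm_utils.h" included above.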
struct CellParams {
  CellParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
  }

  CellParams(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
  }

  CellParams& operator=(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
    return *this;
  }

  void initParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    w_ih = copy_ctor(_w_ih);
    w_hh = copy_ctor(_w_hh);
    b_ih = copy_ctor(_b_ih);
    b_hh = copy_ctor(_b_hh);
    context = _context;
  }

  Tensor w_ih;
  Tensor w_hh;
  Tensor b_ih; // may be undefined when the cell has no biases
  Tensor b_hh; // may be undefined when the cell has no biases
  CPUContext* context;

  Tensor linear_ih(const Tensor& input) const {
    return linear(input, w_ih, b_ih, context);
  }

  Tensor linear_hh(const Tensor& h) const {
    return linear(h, w_hh, b_hh, context);
  }
};
struct LSTMCell {
  explicit LSTMCell(CPUContext* context) : context_(context) {}

  t_tuple operator()(
      const Tensor& input,
      const t_tuple& hidden,
      const CellParams& params) const {
    const auto& hx = std::get<0>(hidden);
    const auto& cx = std::get<1>(hidden);
    auto linear_ih = params.linear_ih(input);
    auto linear_hh = params.linear_hh(hx);
    auto gates = add(linear_ih, linear_hh, context_);
    // Split the fused gate pre-activations along dim 1 into the input,
    // forget, cell, and output gates.
    auto chunked_gates = chunk(gates, 4, 1, context_);
    auto ingate = sigmoid(chunked_gates[0]);
    auto forgetgate = sigmoid(chunked_gates[1]);
    auto cellgate = tanh(chunked_gates[2], context_);
    auto outgate = sigmoid(chunked_gates[3]);
    // c_t = f_t * c_{t-1} + i_t * g_t
    auto cy =
        add(mul(forgetgate, cx, context_),
            mul(ingate, cellgate, context_),
            context_);
    // h_t = o_t * tanh(c_t)
    auto hy = mul(outgate, tanh(cy, context_), context_);
    return std::make_tuple(std::move(hy), std::move(cy));
  }

  CPUContext* context_;
};
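// A minimal usage sketch for LSTMCell (the names ctx, x, h0, c0, w_ih, w_hh,
// b_ih, b_hh are hypothetical), running a single time step:
//
//   LSTMCell cell(ctx);
//   CellParams p(w_ih, w_hh, b_ih, b_hh, ctx);
//   t_tuple state =
//       cell(x, std::make_tuple(std::move(h0), std::move(c0)), p);
//   // std::get<0>(state) is the new hidden state h1,
//   // std::get<1>(state) is the new cell state c1.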
template <typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;

  LayerOutput(const output_type& _outputs, const hidden_type& _hidden) {
    outputs = copy_ctor(_outputs);
    final_hidden = copy_ctor(_hidden);
  }
};
template <typename hidden_type, typename param_type>
struct Layer {
  using output_type = LayerOutput<Tensor, hidden_type>;

  virtual ~Layer() {}
  virtual output_type operator()(
      const Tensor& input,
      const hidden_type& input_hidden,
      const param_type& params) const = 0;
};
struct FullLSTMLayer : Layer<t_tuple, CellParams> {
  FullLSTMLayer(LSTMCell& cell, CPUContext* context)
      : cell_(cell), context_(context) {}

  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
      const std::vector<Tensor>& step_inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const {
    std::vector<Tensor> step_outputs;
    auto hidden = copy_ctor(input_hidden);

    // Run the cell once per time step, collecting each step's hidden state.
    for (size_t i = 0; i < step_inputs.size(); i++) {
      hidden = cell_(step_inputs[i], hidden, params);
      step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
    }

    return {step_outputs, hidden};
  }

  LayerOutput<Tensor, t_tuple> operator()(
      const Tensor& inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const override {
    // Unbind along the time dimension, run the step-wise overload, and
    // stack the per-step outputs back into a single tensor.
    auto unstacked_output =
        (*this)(unbind(inputs, 0, context_), input_hidden, params);
    return {stack(unstacked_output.outputs, 0, context_),
            unstacked_output.final_hidden};
  }

  LSTMCell cell_;
  CPUContext* context_;
};
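// The bidirectional layer below runs a FullLSTMLayer over the sequence in
// both directions and concatenates the two results along the feature
// dimension, so its output feature size is twice that of a single direction.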
struct FullBidirectionalLSTMLayer
    : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
  using param_type = std::pair<CellParams, CellParams>;
  using output_type = LayerOutput<Tensor, bidir_hidden_type>;

  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
      : layer_(cell, context), context_(context) {}

  output_type operator()(
      const Tensor& input,
      const bidir_hidden_type& input_hidden,
      const param_type& params) const override {
    std::vector<Tensor> outputs;
    auto step_inputs = unbind(input, 0, context_);
    // Forward pass over the sequence.
    auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
    auto fw_output = stack(fw_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(fw_output));
    // Backward pass: reverse the inputs, run the layer, then restore the
    // outputs to chronological order.
    auto rev_step_inputs = reverse(std::move(step_inputs));
    auto rev_result =
        layer_(rev_step_inputs, input_hidden.second, params.second);
    std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
    auto rev_output = stack(rev_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(rev_output));
    // Concatenate both directions along the feature dimension.
    return {cat(outputs, fw_output.dim() - 1, context_),
            std::make_pair(
                std::move(fw_result.final_hidden),
                std::move(rev_result.final_hidden))};
  }

  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
    std::reverse(x.begin(), x.end());
    return std::move(x);
  }

 private:
  FullLSTMLayer layer_;
  CPUContext* context_;
};
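// apply_layer_stack() below chains num_layers applications of `layer`: the
// stacked outputs of layer l become the inputs of layer l + 1, while each
// layer's final hidden state is collected for the caller.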
template <typename hidden_type, typename weight_type>
LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
    const Layer<hidden_type, weight_type>& layer,
    const Tensor& input,
    const std::vector<hidden_type>& hiddens,
    const std::vector<weight_type>& weights,
    int64_t num_layers) {
  CAFFE_ENFORCE(
      num_layers == hiddens.size(),
      "Expected more hidden states in stacked_rnn");
  CAFFE_ENFORCE(
      num_layers == weights.size(), "Expected more weights in stacked_rnn");

  auto layer_input = input.UnsafeSharedInstance();
  auto hidden_it = hiddens.begin();
  auto weight_it = weights.begin();
  std::vector<hidden_type> final_hiddens(num_layers);
  for (int64_t l = 0; l < num_layers; ++l) {
    auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
    final_hiddens.at(l) = std::move(layer_output.final_hidden);
    layer_input = std::move(layer_output.outputs);
  }
  return {layer_input, final_hiddens};
}
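// _lstm_impl() splits the initial hidden and cell states per layer, runs
// either the unidirectional or the bidirectional layer stack over the input
// sequence, and re-stacks the per-layer final states into single hy / cy
// tensors.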
std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
    const Tensor& input,
    const std::vector<CellParams>& params,
    const Tensor& hx,
    const Tensor& cx,
    int64_t num_layers,
    bool bidirectional,
    CPUContext* context) {
  using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
  auto layer_hx = unbind(hx, 0, context);
  auto layer_cx = unbind(cx, 0, context);
  int64_t total_layers = layer_hx.size();
  std::vector<std::tuple<Tensor, Tensor>> hiddens;
  hiddens.reserve(total_layers);
  for (int64_t i = 0; i < total_layers; ++i) {
    hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
  }
  LSTMCell cell(context);
  std::shared_ptr<stack_output> stack_output_ptr;
  if (bidirectional) {
    // Pair up the per-direction hidden states and parameters, one pair per
    // stacked layer.
    auto bidir_result = apply_layer_stack(
        FullBidirectionalLSTMLayer{cell, context},
        input,
        pair_vec(hiddens),
        pair_vec(params),
        num_layers);
    stack_output_ptr.reset(new stack_output(
        bidir_result.outputs,
        unpair_vec(std::move(bidir_result.final_hidden))));
  } else {
    auto result = apply_layer_stack(
        FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
    stack_output_ptr = std::make_shared<stack_output>(std::move(result));
  }

  std::vector<Tensor> hy, cy;
  hy.reserve(total_layers);
  cy.reserve(total_layers);
  for (auto& hidden : stack_output_ptr->final_hidden) {
    hy.push_back(std::move(std::get<0>(hidden)));
    cy.push_back(std::move(std::get<1>(hidden)));
  }
  return std::make_tuple(
      std::move(stack_output_ptr->outputs),
      stack(hy, 0, context),
      stack(cy, 0, context));
}
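// gather_params() expects the flat weight list in blocks of
// [w_ih, w_hh, b_ih, b_hh] (or [w_ih, w_hh] without biases), one block per
// direction per layer; e.g. a 2-layer bidirectional LSTM with biases would
// supply 2 * 2 * 4 = 16 tensors.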
std::vector<CellParams> gather_params(
    const std::vector<Tensor>& params,
    bool has_biases,
    CPUContext* context) {
  Tensor undefined;
  std::vector<CellParams> result;
  if (has_biases) {
    CAFFE_ENFORCE_EQ(
        params.size() % 4, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 4) {
      result.emplace_back(
          params[i], params[i + 1], params[i + 2], params[i + 3], context);
    }
  } else {
    CAFFE_ENFORCE_EQ(
        params.size() % 2, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 2) {
      result.emplace_back(
          params[i], params[i + 1], undefined, undefined, context);
    }
  }
  return result;
}

} // namespace
class InferenceLSTMOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit InferenceLSTMOp(Args&&... args)
      : Operator(std::forward<Args>(args)...),
        num_layers_(this->template GetSingleArgument<int64_t>("num_layers", 1)),
        bidirectional_(
            this->template GetSingleArgument<bool>("bidirectional", false)),
        has_biases_(this->template GetSingleArgument<bool>("has_biases", true)),
        batch_first_(
            this->template GetSingleArgument<bool>("batch_first", false)) {}

  bool RunOnDevice() override;

 protected:
  int64_t num_layers_;
  bool bidirectional_;
  bool has_biases_;
  bool batch_first_;
};

} // namespace caffe2