Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_op_cudnn.cc
#include "caffe2/operators/recurrent_op_cudnn.h"
#include "caffe2/utils/math.h"

#include <map>

namespace caffe2 {

namespace detail {

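// TensorDescriptors owns one cudnnTensorDescriptor_t per timestep: the cuDNN
// R5 RNN calls below take an array of seqLength descriptors, one per step,
// rather than a single packed descriptor.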
template <typename T>
TensorDescriptors<T>::TensorDescriptors(
    size_t n,
    const std::vector<int>& dim,
    const std::vector<int>& stride) {
  descs_.resize(n);
  CAFFE_ENFORCE_EQ(dim.size(), stride.size());
  for (auto i = 0; i < n; ++i) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&descs_[i]));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        descs_[i],
        cudnnTypeWrapper<T>::type,
        dim.size(),
        dim.data(),
        stride.data()));
  }
}

template <typename T>
TensorDescriptors<T>::~TensorDescriptors() {
  for (auto desc : descs_) {
    cudnnDestroyTensorDescriptor(desc);
  }
}
} // namespace detail

template <typename T>
RecurrentBaseOp<T>::RecurrentBaseOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
  CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
  CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
  CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cxDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hyDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cyDesc_));
}

template <typename T>
RecurrentBaseOp<T>::~RecurrentBaseOp() {
  CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
  CUDNN_ENFORCE(cudnnDestroyRNNDescriptor(rnnDesc_));
  CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(wDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hxDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cxDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hyDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cyDesc_));
}

template <typename T>
void RecurrentBaseOp<T>::initialize(
    const Tensor<CUDAContext>& input,
    Tensor<CUDAContext>* dropoutStates,
    Tensor<CUDAContext>* output,
    Tensor<CUDAContext>* hiddenOutput,
    Tensor<CUDAContext>* cellOutput) {
  static_assert(sizeof(T) == 4, ""); // workaround clang bug
  CAFFE_ENFORCE_GE(input.ndim(), 3);
  const int seqLength = input.dim(0);
  const int batchSize = input.dim(1);
  const int inputDim = input.dim(2);
  const int hiddenSize = OperatorBase::GetSingleArgument<int>("hidden_size", 0);
  CAFFE_ENFORCE_GT(hiddenSize, 0);
  const auto bidirectional =
      OperatorBase::GetSingleArgument<int>("bidirectional", 0);
  CAFFE_ENFORCE(bidirectional == 0 || bidirectional == 1);
  const auto numDirections = bidirectional == 1 ? 2 : 1;
  const auto outputDim = hiddenSize * numDirections;
  const auto rnnDirection =
      bidirectional == 1 ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  const auto numLayers = OperatorBase::GetSingleArgument<int>("num_layers", 0);
  CAFFE_ENFORCE_GT(numLayers, 0);
  const auto& rnnModeStr =
      OperatorBase::GetSingleArgument<string>("rnn_mode", "");
  CAFFE_ENFORCE(rnnModeStr == "lstm" || rnnModeStr == "gru");
  const auto rnnMode = rnnModeStr == "lstm" ? CUDNN_LSTM : CUDNN_GRU;
  const auto& rnnInputStr =
      OperatorBase::GetSingleArgument<string>("input_mode", "");
  CAFFE_ENFORCE(rnnInputStr == "linear" || rnnInputStr == "skip");
  const auto rnnInput =
      rnnInputStr == "linear" ? CUDNN_LINEAR_INPUT : CUDNN_SKIP_INPUT;

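  // cuDNN interprets the dropout value as the probability of zeroing an
  // activation. The descriptor below is configured only when the "dropout"
  // argument is < 1.0; with the default of 1.0, the descriptor created in the
  // constructor is left unset and no dropout state is allocated.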
  // Dropout setup
  {
    if (dropoutStates) {
      size_t stateSize;
      float dropout_param =
          OperatorBase::GetSingleArgument<float>("dropout", 1.0);
      if (dropout_param < 1.0) {
        CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
            cudnn_wrapper_.inline_cudnn_handle(), &stateSize));
        dropoutStates->Resize(std::vector<int>{static_cast<int>(
            stateSize / 4 /* sizeof(T) - workaround clang bug */)});
        CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
            dropoutDesc_,
            cudnn_wrapper_.inline_cudnn_handle(),
            dropout_param,
            dropoutStates->template mutable_data<T>(),
            stateSize,
            OperatorBase::GetSingleArgument<int>("seed", 0)));
      }
    }
  }

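  // cuDNN 7 changed cudnnSetRNNDescriptor to take the cuDNN handle and an
  // explicit algorithm selector, hence the two call signatures below.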
  // RNN setup
  {
#if CUDNN_VERSION_MIN(7, 0, 0)
    CUDNN_ENFORCE(cudnnSetRNNDescriptor(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        hiddenSize,
        numLayers,
        dropoutDesc_,
        rnnInput,
        rnnDirection,
        rnnMode,
        CUDNN_RNN_ALGO_STANDARD, // TODO: verify correctness / efficiency.
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetRNNDescriptor(
        rnnDesc_,
        hiddenSize,
        numLayers,
        dropoutDesc_,
        rnnInput,
        rnnDirection,
        rnnMode,
        cudnnTypeWrapper<T>::type));
#endif
  }
  // X setup
  {
    xDesc_.reset(new detail::TensorDescriptors<T>(
        seqLength,
        // Third dimension is unused
        {batchSize, inputDim, 1},
        // Fully-packed
        {inputDim, 1, 1}));
  }
  // Y setup
  {
    yDesc_.reset(new detail::TensorDescriptors<T>(
        seqLength,
        // Third dimension is unused
        {batchSize, hiddenSize * numDirections, 1},
        // Fully-packed
        {numDirections * hiddenSize, 1, 1}));

    if (output) {
      output->Resize(std::vector<int>{seqLength, batchSize, outputDim});
    }
  }

  // Hidden/Cell setup
  {
    const std::array<int, 3> dim{
        numLayers * numDirections, batchSize, hiddenSize};
    const std::array<int, 3> stride{batchSize * hiddenSize, hiddenSize, 1};
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        hxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        cxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        hyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        cyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));

    if (hiddenOutput) {
      hiddenOutput->Resize(
          std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
    }

    if (cellOutput) {
      cellOutput->Resize(
          std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
    }
  }

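  // cuDNN packs the weights and biases of every layer and direction into a
  // single opaque buffer; wDesc_ below describes that buffer as a flat filter
  // of weightsSize / sizeof(T) elements.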
  // Weights setup
  {
    size_t weightsSize;
    CUDNN_ENFORCE(cudnnGetRNNParamsSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        xDesc_->descs()[0],
        &weightsSize,
        cudnnTypeWrapper<T>::type));
    const std::array<int, 3> dims{
        static_cast<int>(
            weightsSize / 4 /* sizeof(T) - workaround clang bug */),
        1,
        1};
    CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
        wDesc_, cudnnTypeWrapper<T>::type, CUDNN_TENSOR_NCHW, 3, dims.data()));
  }

  // RNN workspace size
  {
    CUDNN_ENFORCE(cudnnGetRNNWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        seqLength,
        xDesc_->descs(),
        &cudnnWsNbytes_));
  }
}

template <typename T>
bool RecurrentOp<T>::RunOnDevice() {
  const int seqLength = Input(INPUT).dim32(0);
  if (Input(INPUT).dims() != cachedInputDims_) {
    initialize(
        Input(INPUT),
        Output(DROPOUT_STATES),
        Output(OUTPUT),
        Output(HIDDEN_OUTPUT),
        Output(CELL_OUTPUT));
    cachedInputDims_ = Input(INPUT).dims();
  }

  // Validation checks
  size_t weightsSize;
  CUDNN_ENFORCE(cudnnGetRNNParamsSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      xDesc_->descs()[0],
      &weightsSize,
      cudnnTypeWrapper<T>::type));
  CAFFE_ENFORCE_EQ(Input(WEIGHT).nbytes(), weightsSize);

  // Training reserve size
  CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      seqLength,
      xDesc_->descs(),
      &reserveNbytes_));
  Output(RNN_SCRATCH)
      ->Resize(std::vector<int>{static_cast<int>(
          reserveNbytes_ / 4)}); // sizeof(T) - workaround clang bug
  Output(RNN_SCRATCH)->template mutable_data<T>();

  auto InputData = [this](int i) { return this->Input(i).template data<T>(); };
  auto OutputData = [this](int i) {
    return this->Output(i)->template mutable_data<T>();
  };

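  // The inference path needs no reserve space; cudnnRNNForwardTraining
  // additionally fills RNN_SCRATCH, which the backward pass consumes.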
  if (OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
    cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
      CUDNN_ENFORCE(cudnnRNNForwardInference(
          state->cudnn_handle(),
          rnnDesc_,
          seqLength,
          xDesc_->descs(),
          InputData(INPUT),
          hxDesc_,
          InputData(HIDDEN_INPUT),
          cxDesc_,
          InputData(CELL_INPUT),
          wDesc_,
          InputData(WEIGHT),
          yDesc_->descs(),
          OutputData(OUTPUT),
          hyDesc_,
          OutputData(HIDDEN_OUTPUT),
          cyDesc_,
          OutputData(CELL_OUTPUT),
          state->workspace().get(cudnnWsNbytes_),
          cudnnWsNbytes_));
    });
  } else {
    cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
      CUDNN_ENFORCE(cudnnRNNForwardTraining(
          state->cudnn_handle(),
          rnnDesc_,
          seqLength,
          xDesc_->descs(),
          InputData(INPUT),
          hxDesc_,
          InputData(HIDDEN_INPUT),
          cxDesc_,
          InputData(CELL_INPUT),
          wDesc_,
          InputData(WEIGHT),
          yDesc_->descs(),
          OutputData(OUTPUT),
          hyDesc_,
          OutputData(HIDDEN_OUTPUT),
          cyDesc_,
          OutputData(CELL_OUTPUT),
          state->workspace().get(cudnnWsNbytes_),
          cudnnWsNbytes_,
          OutputData(RNN_SCRATCH),
          reserveNbytes_));
    });
  }

  return true;
}

template <typename T>
bool RecurrentGradientOp<T>::RunOnDevice() {
  const int seqLength = Input(INPUT).dim32(0);
  if (Input(INPUT).dims() != cachedInputDims_) {
    initialize(Input(INPUT), Output(DROPOUT_STATES));
    cachedInputDims_ = Input(INPUT).dims();
  }
  CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      seqLength,
      xDesc_->descs(),
      &reserveNbytes_));
  CAFFE_ENFORCE_EQ(reserveNbytes_, Input(RNN_SCRATCH).nbytes());
  Output(GRAD_INPUT)->ResizeLike(Input(INPUT));
  Output(GRAD_HIDDEN_INPUT)->ResizeLike(Input(HIDDEN_INPUT));
  Output(GRAD_CELL_INPUT)->ResizeLike(Input(CELL_INPUT));

  Output(GRAD_WEIGHT)->ResizeLike(Input(WEIGHT));
  math::Set<T, CUDAContext>(
      Output(GRAD_WEIGHT)->size(),
      0.0,
      Output(GRAD_WEIGHT)->template mutable_data<T>(),
      &context_);

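  // From cuDNN 6 onward the reserve space passed to the backward calls is
  // mutable (the library may update it in place), hence the version split.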
#if CUDNN_VERSION_MIN(6, 0, 0)
  auto* reserve = Output(RNN_SCRATCH_OUT)->template mutable_data<T>();
#else
  const auto* reserve = Output(RNN_SCRATCH_OUT)->template data<T>();
#endif
  auto InputData = [this](int i) { return this->Input(i).template data<T>(); };
  auto OutputData = [this](int i) {
    return this->Output(i)->template mutable_data<T>();
  };

  cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnRNNBackwardData(
        state->cudnn_handle(),
        rnnDesc_,
        seqLength,
        yDesc_->descs(),
        InputData(OUTPUT),
        yDesc_->descs(),
        InputData(GRAD_OUTPUT),
        hyDesc_,
        // Note: like CNTK, ignore these gradient inputs. t16675365 to
        // reconsider.
        nullptr,
        cyDesc_,
        nullptr,
        wDesc_,
        InputData(WEIGHT),
        hxDesc_,
        InputData(HIDDEN_INPUT),
        cxDesc_,
        InputData(CELL_INPUT),
        xDesc_->descs(),
        OutputData(GRAD_INPUT),
        hxDesc_,
        OutputData(GRAD_HIDDEN_INPUT),
        cxDesc_,
        OutputData(GRAD_CELL_INPUT),
        state->workspace().get(cudnnWsNbytes_),
        cudnnWsNbytes_,
        reserve,
        reserveNbytes_));
    CUDNN_ENFORCE(cudnnRNNBackwardWeights(
        state->cudnn_handle(),
        rnnDesc_,
        seqLength,
        xDesc_->descs(),
        InputData(INPUT),
        hxDesc_,
        InputData(HIDDEN_INPUT),
        yDesc_->descs(),
        InputData(OUTPUT),
        state->workspace().get(cudnnWsNbytes_),
        cudnnWsNbytes_,
        wDesc_,
        OutputData(GRAD_WEIGHT),
        reserve,
        reserveNbytes_));
  });

  return true;
}

template <typename T, RecurrentParamOpMode mode>
bool RecurrentParamAccessOp<T, mode>::RunOnDevice() {
  initialize(Input(0));

  if (mode == SET_PARAM) {
    size_t paramsSize;
    CUDNN_ENFORCE(cudnnGetRNNParamsSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        xDesc_->descs()[0],
        &paramsSize,
        cudnnTypeWrapper<T>::type));

    CAFFE_ENFORCE_EQ(
        paramsSize / 4, Input(1).size(), "Incorrect weight initialization");
  }

  int layer = OperatorBase::GetSingleArgument<int>("layer", 0);
  std::string param_type =
      OperatorBase::GetSingleArgument<string>("param_type", "");
  std::string input_type =
      OperatorBase::GetSingleArgument<string>("input_type", "");

  // Mapping to CUDNN constants
  std::map<string, int> weight_constants = {{"input_gate_w", 0},
                                            {"forget_gate_w", 1},
                                            {"cell_w", 2},
                                            {"output_gate_w", 3}};
  std::map<string, int> bias_constants = {{"input_gate_b", 0},
                                          {"forget_gate_b", 1},
                                          {"cell_b", 2},
                                          {"output_gate_b", 3}};
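  // For cuDNN LSTMs, linear-layer IDs 0-3 select the input GEMMs and 4-7 the
  // corresponding recurrent GEMMs (input gate, forget gate, cell, output
  // gate), hence the +4 offset applied when input_type == "recurrent".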
  if (bias_constants.find(param_type) != bias_constants.end()) {
    int param_id = bias_constants[param_type] + 4 * (input_type == "recurrent");

    cudnnFilterDescriptor_t biasDesc;
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&biasDesc));
    void* bias;

    CUDNN_ENFORCE(cudnnGetRNNLinLayerBiasParams(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        layer,
        xDesc_->descs()[0],
        wDesc_,
        Input(1).template data<T>(),
        param_id, // which bias within the layer; see the ID mapping above
        biasDesc,
        &bias));
    int numBiasDims;
    std::vector<int> biasDims(3);
    cudnnDataType_t dt;
    cudnnTensorFormat_t tf;
    // For some reason, the CuDNN Bias tensor is 3 dimensional
    CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
        biasDesc, 3, &dt, &tf, &numBiasDims, biasDims.data()));
    CAFFE_ENFORCE_EQ(numBiasDims, 3);

    if (mode == SET_PARAM) {
      CAFFE_ENFORCE_EQ(
          biasDims[0] * biasDims[1] * biasDims[2], Input(2).size());
      context_.template Copy<T, CUDAContext, CUDAContext>(
          biasDims[0] * biasDims[1] * biasDims[2],
          Input(2).template data<T>(),
          static_cast<T*>(bias));
    } else {
      Output(0)->Resize(biasDims);
      context_.template Copy<T, CUDAContext, CUDAContext>(
          biasDims[0] * biasDims[1] * biasDims[2],
          static_cast<T*>(bias),
          Output(0)->template mutable_data<T>());
    }
    // Release the temporary descriptor so repeated calls do not leak it.
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(biasDesc));
  } else if (weight_constants.find(param_type) != weight_constants.end()) {
    int param_id =
        weight_constants[param_type] + 4 * (input_type == "recurrent");
    cudnnFilterDescriptor_t matrixParamDesc;
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&matrixParamDesc));
    void* pmatrix;
    CUDNN_ENFORCE(cudnnGetRNNLinLayerMatrixParams(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        layer,
        xDesc_->descs()[0],
        wDesc_,
        Input(1).template data<T>(),
        param_id, // which matrix within the layer; see the ID mapping above
        matrixParamDesc,
        &pmatrix));
    int numDims;
    std::vector<int> matDims(3);
    cudnnDataType_t dt;
    cudnnTensorFormat_t tf;

    CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
        matrixParamDesc, 3, &dt, &tf, &numDims, matDims.data()));
    CAFFE_ENFORCE_EQ(numDims, 3);
    if (mode == SET_PARAM) {
      CAFFE_ENFORCE_EQ(matDims[0] * matDims[1] * matDims[2], Input(2).size());
      context_.template Copy<T, CUDAContext, CUDAContext>(
          matDims[0] * matDims[1] * matDims[2],
          Input(2).template data<T>(),
          static_cast<T*>(pmatrix));
    } else {
      Output(0)->Resize(matDims);
      context_.template Copy<T, CUDAContext, CUDAContext>(
          matDims[0] * matDims[1] * matDims[2],
          static_cast<T*>(pmatrix),
          Output(0)->template mutable_data<T>());
    }
    // Release the temporary descriptor so repeated calls do not leak it.
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(matrixParamDesc));
  } else {
    CAFFE_ENFORCE(false, "Unknown param type:", param_type);
  }

  return true;
}

REGISTER_CUDNN_OPERATOR(Recurrent, RecurrentOp<float>);
OPERATOR_SCHEMA(Recurrent).NumInputs(4).NumOutputs(5).SetDoc(R"DOC(

Recurrent wraps the CuDNN R5 RNN implementation. See the CuDNN R5
documentation for more information.

In general, the implementation takes an input (TxNxD) tensor, the
hidden state input (NxD), the cell input (NxD), and a weight tensor
(effectively an opaque blob, whose size and layout are dictated by
CuDNN).

The outputs are the output (again, TxNxD) and the final hidden/cell
states (NxD). These can be reset (at sequence boundaries across
minibatches) by multiplying by zero.

The CuDNN arguments (hidden_size, bidirectional, num_layers, rnn_mode,
input_mode) are passed directly through to CuDNN.

)DOC");
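// A minimal usage sketch (an editor's illustration, not part of the original
// file): building an OperatorDef for this operator from C++. The blob names
// and argument values are assumptions; the type, engine, input/output order,
// and argument names follow the schema and code above.
//
//   OperatorDef def;
//   def.set_type("Recurrent");
//   def.set_engine("CUDNN");
//   def.add_input("input");           // INPUT: (T, N, D)
//   def.add_input("hidden_init");     // HIDDEN_INPUT
//   def.add_input("cell_init");       // CELL_INPUT
//   def.add_input("weights");         // WEIGHT (opaque cuDNN layout)
//   def.add_output("output");         // OUTPUT
//   def.add_output("hidden_output");  // HIDDEN_OUTPUT
//   def.add_output("cell_output");    // CELL_OUTPUT
//   def.add_output("rnn_scratch");    // RNN_SCRATCH
//   def.add_output("dropout_states"); // DROPOUT_STATES
//   auto* arg = def.add_arg();
//   arg->set_name("hidden_size");
//   arg->set_i(64);
//   // Likewise set num_layers, bidirectional, rnn_mode ("lstm"/"gru"),
//   // and input_mode ("linear"/"skip").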
REGISTER_CUDNN_OPERATOR(RecurrentGradient, RecurrentGradientOp<float>);
OPERATOR_SCHEMA(RecurrentGradient)
    .NumInputs(7)
    .NumOutputs(6)
    .AllowInplace({{4, 5}});

REGISTER_CUDNN_OPERATOR(
    RecurrentParamSet,
    RecurrentParamAccessOp<float, SET_PARAM>);
OPERATOR_SCHEMA(RecurrentParamSet)
    .NumInputs(3)
    .NumOutputs(1)
    .EnforceInplace({{1, 0}})
    .SetDoc("Set individual parameters of a recurrent net.")
    .Arg("param_type", R"DOC(Type of param to be set:
        "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w",
        "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b"
        )DOC")
    .Arg("input_type", "'recurrent' or 'input'")
    .Arg("layer", "layer index (starting from 0)")
    .Input(0, "input", R"DOC(Input blob. Needed for inferring the shapes.
        A dummy tensor matching the input shape is ok.)DOC")
    .Input(1, "all_params", "Blob holding all the parameters")
    .Input(2, "param", "Values for the specified parameter")
    .Output(
        0,
        "all_params",
        "Blob holding all the parameters (same as input(1))");

REGISTER_CUDNN_OPERATOR(
    RecurrentParamGet,
    RecurrentParamAccessOp<float, GET_PARAM>);
OPERATOR_SCHEMA(RecurrentParamGet)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc("Retrieve individual parameters of a recurrent net op.")
    .Arg("param_type", R"DOC(Type of param to be retrieved:
        "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w",
        "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b"
        )DOC")
    .Arg("input_type", "'recurrent' or 'input'")
    .Arg("layer", "layer index (starting from 0)")
    .Input(0, "input", R"DOC(Input blob. Needed for inferring the shapes.
        A dummy tensor matching the input shape is ok.)DOC")
    .Input(1, "all_params", "Blob holding all the parameters")
    .Output(0, "param", "Blob holding the requested values");

struct GetRecurrentGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "RecurrentGradient",
        "",
        vector<string>{I(0), // INPUT
                       I(1), // HIDDEN_INPUT
                       I(2), // CELL_INPUT
                       I(3), // WEIGHT
                       O(3), // RNN_SCRATCH
                       O(0), // OUTPUT
                       GO(0)}, // GRAD_OUTPUT
        // TODO: not currently using these gradients, investigate t16675365
        // GO(1), // GRAD_HIDDEN_OUTPUT
        // GO(2)}, // GRAD_CELL_OUTPUT
        vector<string>{
            GI(0), // GRAD_INPUT
            GI(1), // GRAD_HIDDEN_INPUT
            GI(2), // GRAD_CELL_INPUT
            GI(3), // GRAD_WEIGHT
            O(4), // DROPOUT_STATES
            O(3) // RNN_SCRATCH
        });
  }
};
REGISTER_GRADIENT(Recurrent, GetRecurrentGradient);
} // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.