// NOTE(review): this chunk is a partially-extracted copy of Caffe2's
// recurrent_op_cudnn.cc. The leading numbers on each line are the ORIGINAL
// file's line numbers; gaps in them (e.g. 17 -> 18 -> 19 -> 21) mean interior
// lines are MISSING from this copy. Restore from the upstream file before
// attempting to compile — do not hand-reconstruct the missing arguments.
//
// TensorDescriptors<T> ctor: allocates one cudnnTensorDescriptor_t per
// time step; every descriptor is configured with the same `dim`/`stride`
// (via cudnnSetTensorNdDescriptor, arguments partially missing here).
1 #include "caffe2/operators/rnn/recurrent_op_cudnn.h" 2 #include "caffe2/utils/math.h" 11 TensorDescriptors<T>::TensorDescriptors(
13 const std::vector<int>& dim,
14 const std::vector<int>& stride) {
// dim and stride must describe the same rank.
16 CAFFE_ENFORCE_EQ(dim.size(), stride.size());
// `n` is presumably the step count passed as the (missing) first ctor
// parameter sizing descs_ — TODO confirm against the upstream header.
17 for (
auto i = 0; i < n; ++i) {
18 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&descs_[i]));
19 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
21 cudnnTypeWrapper<T>::type,
// Dtor: destroys each cuDNN tensor descriptor created by the ctor.
// The return status is deliberately NOT wrapped in CUDNN_ENFORCE here —
// destructors must not throw.
29 TensorDescriptors<T>::~TensorDescriptors() {
30 for (
auto desc : descs_) {
31 cudnnDestroyTensorDescriptor(desc);
// Dtor: releases the cuDNN descriptors owned by the base op (dropout, RNN,
// weight filter, and the hidden-state tensor descriptor).
// NOTE(review): extraction gap after original line 41 — there may be
// additional descriptor teardown lines missing here; confirm upstream.
// NOTE(review): CUDNN_ENFORCE in a destructor can throw during stack
// unwinding; this mirrors the upstream file's style and is kept as-is.
37 RecurrentBaseOp<T>::~RecurrentBaseOp() {
38 CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
39 CUDNN_ENFORCE(cudnnDestroyRNNDescriptor(rnnDesc_));
40 CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(wDesc_));
41 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hxDesc_));
// initialize(): (re)configures all cuDNN state for an RNN whose geometry is
// derived from `input` (expected layout: seqLength x batchSize x inputDim)
// plus operator arguments:
//   hidden_size   (> 0, required)
//   bidirectional (0 or 1)
//   num_layers    (> 0, required)
//   rnn_mode      ("lstm" or "gru")
//   input_mode    ("linear" or "skip")
//   dropout       (default 1.0 == keep-probability 1, i.e. disabled)
//   seed          (dropout RNG seed)
// It sets up the dropout descriptor, the RNN descriptor, per-step x/y tensor
// descriptors, the hidden-state descriptor, resizes the output blobs, and
// computes the weight-filter and workspace sizes.
// NOTE(review): many interior lines (descriptor arguments, several sizing
// statements) are missing from this extracted copy — restore from upstream.
45 void RecurrentBaseOp<T>::initialize(
// Only 4-byte element types (float) are supported by this code path.
51 static_assert(
sizeof(
T) == 4,
"");
52 CAFFE_ENFORCE_GE(input.dim(), 3);
53 const int seqLength = input.size(0);
54 const int batchSize = input.size(1);
55 const int inputDim = input.size(2);
56 const int hiddenSize = OperatorBase::GetSingleArgument<int>(
"hidden_size", 0);
57 CAFFE_ENFORCE_GT(hiddenSize, 0);
58 const auto bidirectional =
59 OperatorBase::GetSingleArgument<int>(
"bidirectional", 0);
60 CAFFE_ENFORCE(bidirectional == 0 || bidirectional == 1);
// A bidirectional net doubles the per-step output width.
61 const auto numDirections = bidirectional == 1 ? 2 : 1;
62 const auto outputDim = hiddenSize * numDirections;
63 const auto rnnDirection =
64 bidirectional == 1 ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
65 const auto numLayers = OperatorBase::GetSingleArgument<int>(
"num_layers", 0);
66 CAFFE_ENFORCE_GT(numLayers, 0);
// Cell type: only LSTM and GRU are accepted.
67 const auto& rnnModeStr =
68 OperatorBase::GetSingleArgument<string>(
"rnn_mode",
"");
69 CAFFE_ENFORCE(rnnModeStr ==
"lstm" || rnnModeStr ==
"gru");
70 const auto rnnMode = rnnModeStr ==
"lstm" ? CUDNN_LSTM : CUDNN_GRU;
// Input mode: "linear" applies a learned input projection; "skip" feeds
// the input straight into the first layer.
71 const auto& rnnInputStr =
72 OperatorBase::GetSingleArgument<string>(
"input_mode",
"");
73 CAFFE_ENFORCE(rnnInputStr ==
"linear" || rnnInputStr ==
"skip");
75 rnnInputStr ==
"linear" ? CUDNN_LINEAR_INPUT : CUDNN_SKIP_INPUT;
// Dropout setup: only configured when dropout < 1.0 (cuDNN needs a state
// buffer, sized by cudnnDropoutGetStatesSize, living in dropoutStates).
82 OperatorBase::GetSingleArgument<float>(
"dropout", 1.0);
83 if (dropout_param < 1.0) {
84 CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
85 cudnn_wrapper_.inline_cudnn_handle(), &stateSize));
86 dropoutStates->Resize(std::vector<int>{
static_cast<int>(
88 CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
90 cudnn_wrapper_.inline_cudnn_handle(),
92 dropoutStates->template mutable_data<T>(),
94 OperatorBase::GetSingleArgument<int>(
"seed", 0)));
// RNN descriptor: the cuDNN >= 7 overload takes an explicit algorithm
// (CUDNN_RNN_ALGO_STANDARD); the pre-7 branch (original line 114) omits it.
101 #if CUDNN_VERSION_MIN(7, 0, 0) 102 CUDNN_ENFORCE(cudnnSetRNNDescriptor(
103 cudnn_wrapper_.inline_cudnn_handle(),
111 CUDNN_RNN_ALGO_STANDARD,
112 cudnnTypeWrapper<T>::type));
114 CUDNN_ENFORCE(cudnnSetRNNDescriptor(
122 cudnnTypeWrapper<T>::type));
// Per-step input/output descriptors: one descriptor per time step, each
// describing a (batchSize, featureDim, 1) slice.
127 xDesc_.reset(
new detail::TensorDescriptors<T>(
130 {batchSize, inputDim, 1},
136 yDesc_.reset(
new detail::TensorDescriptors<T>(
139 {batchSize, hiddenSize * numDirections, 1},
141 {numDirections * hiddenSize, 1, 1}));
144 output->Resize(std::vector<int>{seqLength, batchSize, outputDim});
// Hidden/cell state descriptor: (numLayers * numDirections, batch, hidden)
// with fully-packed strides.
150 const std::array<int, 3> dim{
151 numLayers * numDirections, batchSize, hiddenSize};
152 const std::array<int, 3> stride{batchSize * hiddenSize, hiddenSize, 1};
153 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
154 hxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
160 hiddenOutput->Resize(
161 std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
166 std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
// Weight blob: cuDNN reports the total parameter byte size; the filter
// descriptor is a 3-D NCHW filter wrapping that flat buffer.
173 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
174 cudnn_wrapper_.inline_cudnn_handle(),
178 cudnnTypeWrapper<T>::type));
179 const std::array<int, 3> dims{
184 CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
185 wDesc_, cudnnTypeWrapper<T>::type, CUDNN_TENSOR_NCHW, 3, dims.data()));
// Workspace sizing for the forward/backward calls (stored by the missing
// trailing lines — presumably into cudnnWsNbytes_; confirm upstream).
190 CUDNN_ENFORCE(cudnnGetRNNWorkspaceSize(
191 cudnn_wrapper_.inline_cudnn_handle(),
// Forward pass. Re-runs initialize() only when the input dims change
// (cached in cachedInputDims_), validates the WEIGHT blob against cuDNN's
// reported parameter size, sizes the RNN_SCRATCH reserve buffer, then
// dispatches to cudnnRNNForwardInference (when the is-test arg is set) or
// cudnnRNNForwardTraining (which additionally fills RNN_SCRATCH).
// NOTE(review): interior lines (several call arguments) are missing from
// this extracted copy — restore from upstream before compiling.
199 template <
typename T>
200 bool RecurrentOp<T>::RunOnDevice() {
201 const int seqLength = Input(INPUT).dim32(0);
// Lazy re-initialization: descriptors depend on input geometry.
202 if (Input(INPUT).sizes() != cachedInputDims_) {
205 Output(DROPOUT_STATES),
207 Output(HIDDEN_OUTPUT),
208 Output(CELL_OUTPUT));
209 cachedInputDims_ = Input(INPUT).sizes().vec();
// The caller-supplied weight blob must match cuDNN's expected byte size
// exactly — a mismatch means a wrongly-shaped weight initialization.
214 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
215 cudnn_wrapper_.inline_cudnn_handle(),
219 cudnnTypeWrapper<T>::type));
220 CAFFE_ENFORCE_EQ(Input(WEIGHT).nbytes(), weightsSize);
// Reserve space is needed by the training path (consumed by backward).
// `/ 4` assumes sizeof(T) == 4, enforced in initialize().
223 CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
224 cudnn_wrapper_.inline_cudnn_handle(),
230 ->Resize(std::vector<int>{
static_cast<int>(
231 reserveNbytes_ / 4)});
232 Output(RNN_SCRATCH)->template mutable_data<T>();
// Small accessors to keep the long cuDNN argument lists readable.
234 auto InputData = [
this](
int i) {
return this->Input(i).template data<T>(); };
235 auto OutputData = [
this](
int i) {
236 return this->Output(i)->template mutable_data<T>();
// Inference path: no reserve buffer needed.
239 if (OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
240 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
241 CUDNN_ENFORCE(cudnnRNNForwardInference(
242 state->cudnn_handle(),
248 InputData(HIDDEN_INPUT),
250 InputData(CELL_INPUT),
256 OutputData(HIDDEN_OUTPUT),
258 OutputData(CELL_OUTPUT),
259 state->workspace().get(cudnnWsNbytes_),
// Training path: identical inputs/outputs plus the RNN_SCRATCH reserve.
263 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
264 CUDNN_ENFORCE(cudnnRNNForwardTraining(
265 state->cudnn_handle(),
271 InputData(HIDDEN_INPUT),
273 InputData(CELL_INPUT),
279 OutputData(HIDDEN_OUTPUT),
281 OutputData(CELL_OUTPUT),
282 state->workspace().get(cudnnWsNbytes_),
284 OutputData(RNN_SCRATCH),
// Backward pass. Validates that the forward's reserve buffer (RNN_SCRATCH)
// has the byte size cuDNN expects, shapes the gradient outputs like their
// forward counterparts, zeroes GRAD_WEIGHT (cudnnRNNBackwardWeights
// accumulates into it), then runs cudnnRNNBackwardData followed by
// cudnnRNNBackwardWeights under a shared cuDNN state/workspace.
// NOTE(review): interior lines (several call arguments) are missing from
// this extracted copy — restore from upstream before compiling.
292 template <
typename T>
293 bool RecurrentGradientOp<T>::RunOnDevice() {
294 const int seqLength = Input(INPUT).dim32(0);
295 if (Input(INPUT).sizes() != cachedInputDims_) {
296 initialize(Input(INPUT), Output(DROPOUT_STATES));
297 cachedInputDims_ = Input(INPUT).sizes().vec();
// Sanity-check the reserve buffer produced by the forward pass.
299 CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
300 cudnn_wrapper_.inline_cudnn_handle(),
305 CAFFE_ENFORCE_EQ(reserveNbytes_, Input(RNN_SCRATCH).nbytes());
// Gradients mirror the shapes of their forward inputs.
306 Output(GRAD_INPUT)->ResizeLike(Input(INPUT));
307 Output(GRAD_HIDDEN_INPUT)->ResizeLike(Input(HIDDEN_INPUT));
308 Output(GRAD_CELL_INPUT)->ResizeLike(Input(CELL_INPUT));
// cudnnRNNBackwardWeights accumulates, so start from zero.
310 Output(GRAD_WEIGHT)->ResizeLike(Input(WEIGHT));
311 math::Set<T, CUDAContext>(
312 Output(GRAD_WEIGHT)->numel(),
314 Output(GRAD_WEIGHT)->template mutable_data<T>(),
// cuDNN >= 6 mutates the reserve buffer during backward, hence the
// mutable vs const pointer split on the version gate.
317 #if CUDNN_VERSION_MIN(6,0,0) 318 auto * reserve = Output(RNN_SCRATCH_OUT)->template mutable_data<T>();
320 const auto * reserve = Output(RNN_SCRATCH_OUT)->template data<T>();
322 auto InputData = [
this](
int i) {
return this->Input(i).template data<T>(); };
323 auto OutputData = [
this](
int i) {
324 return this->Output(i)->template mutable_data<T>();
327 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
328 CUDNN_ENFORCE(cudnnRNNBackwardData(
329 state->cudnn_handle(),
335 InputData(GRAD_OUTPUT),
345 InputData(HIDDEN_INPUT),
347 InputData(CELL_INPUT),
349 OutputData(GRAD_INPUT),
351 OutputData(GRAD_HIDDEN_INPUT),
353 OutputData(GRAD_CELL_INPUT),
354 state->workspace().get(cudnnWsNbytes_),
// Weight gradients are computed second, reusing the same workspace.
358 CUDNN_ENFORCE(cudnnRNNBackwardWeights(
359 state->cudnn_handle(),
365 InputData(HIDDEN_INPUT),
368 state->workspace().get(cudnnWsNbytes_),
// Parameter get/set op (mode selects SET_PARAM vs GET_PARAM at compile time).
// Locates an individual gate weight matrix or bias vector inside cuDNN's
// opaque packed parameter blob via cudnnGetRNNLinLayer{Matrix,Bias}Params,
// then copies Input(2) into it (SET) or copies it out to Output(0) (GET).
// Arguments: layer (index), param_type (gate name), input_type
// ("recurrent" or "input").
// NOTE(review): interior lines (descriptor arguments, some declarations such
// as `bias`/`pmatrix`) are missing from this extracted copy — restore from
// upstream before compiling.
380 template <
typename T, RecurrentParamOpMode mode>
381 bool RecurrentParamAccessOp<T, mode>::RunOnDevice() {
382 initialize(Input(0));
// For SET, validate the full packed parameter blob size up front.
// `/ 4` assumes sizeof(T) == 4, enforced in initialize().
384 if (mode == SET_PARAM) {
386 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
387 cudnn_wrapper_.inline_cudnn_handle(),
391 cudnnTypeWrapper<T>::type));
394 paramsSize / 4, Input(1).numel(),
"Incorrect weight initialization");
397 int layer = OperatorBase::GetSingleArgument<int>(
"layer", 0);
398 std::string param_type =
399 OperatorBase::GetSingleArgument<string>(
"param_type",
"");
400 std::string input_type =
401 OperatorBase::GetSingleArgument<string>(
"input_type",
"");
// Gate name -> cuDNN linear-layer id. Ids 0..3 address the input-side
// params; adding 4 below selects the recurrent-side ones. (The "cell_w"/
// "cell_b" entries, id 2, fall in the extraction gap at lines 406/410.)
404 std::map<string, int> weight_constants = {{
"input_gate_w", 0},
405 {
"forget_gate_w", 1},
407 {
"output_gate_w", 3}};
408 std::map<string, int> bias_constants = {{
"input_gate_b", 0},
409 {
"forget_gate_b", 1},
411 {
"output_gate_b", 3}};
// --- Bias path -------------------------------------------------------
412 if (bias_constants.find(param_type) != bias_constants.end()) {
413 int param_id = bias_constants[param_type] + 4 * (input_type ==
"recurrent");
415 cudnnFilterDescriptor_t biasDesc;
416 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&biasDesc));
// Resolves `bias` to a pointer inside the packed blob and fills
// biasDesc with its shape.
419 CUDNN_ENFORCE(cudnnGetRNNLinLayerBiasParams(
420 cudnn_wrapper_.inline_cudnn_handle(),
425 Input(1).template data<T>(),
430 std::vector<int> biasDims(3);
432 cudnnTensorFormat_t tf;
434 CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
435 biasDesc, 3, &dt, &tf, &numBiasDims, biasDims.data()));
436 CAFFE_ENFORCE_EQ(numBiasDims, 3);
438 if (mode == SET_PARAM) {
// SET: copy user values into the located bias region.
440 biasDims[0] * biasDims[1] * biasDims[2], Input(2).numel());
441 this->context_.template CopySameDevice<T>(
442 biasDims[0] * biasDims[1] * biasDims[2],
443 Input(2).template data<T>(),
444 static_cast<T*>(bias));
// GET: copy the located bias region out to Output(0).
446 Output(0)->Resize(biasDims);
447 this->context_.template CopySameDevice<T>(
448 biasDims[0] * biasDims[1] * biasDims[2],
449 static_cast<T*
>(bias),
450 Output(0)->template mutable_data<T>());
// --- Weight-matrix path (mirrors the bias path) ----------------------
452 }
else if (weight_constants.find(param_type) != weight_constants.end()) {
454 weight_constants[param_type] + 4 * (input_type ==
"recurrent");
455 cudnnFilterDescriptor_t matrixParamDesc;
456 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&matrixParamDesc));
458 CUDNN_ENFORCE(cudnnGetRNNLinLayerMatrixParams(
459 cudnn_wrapper_.inline_cudnn_handle(),
464 Input(1).template data<T>(),
469 std::vector<int> matDims(3);
471 cudnnTensorFormat_t tf;
473 CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
474 matrixParamDesc, 3, &dt, &tf, &numDims, matDims.data()));
475 CAFFE_ENFORCE_EQ(numDims, 3);
476 if (mode == SET_PARAM) {
477 CAFFE_ENFORCE_EQ(matDims[0] * matDims[1] * matDims[2], Input(2).numel());
478 this->context_.template CopySameDevice<T>(
479 matDims[0] * matDims[1] * matDims[2],
480 Input(2).template data<T>(),
481 static_cast<T*>(pmatrix));
483 Output(0)->Resize(matDims);
484 this->context_.template CopySameDevice<T>(
485 matDims[0] * matDims[1] * matDims[2],
486 static_cast<T*
>(pmatrix),
487 Output(0)->template mutable_data<T>());
// Unknown param_type: hard failure.
490 CAFFE_ENFORCE(
false,
"Unknown param type:", param_type);
// Operator registrations and schemas:
//   Recurrent / RecurrentGradient  — forward/backward cuDNN RNN ops.
//   RecurrentParamSet / RecurrentParamGet — per-gate parameter access
//     (RecurrentParamAccessOp<float, SET_PARAM / GET_PARAM>).
// NOTE(review): the R"DOC(...)DOC" raw-string literals below were split
// mid-token by the extraction (the `R` prefix is separated from its
// delimiter) and doc lines were collapsed; the literals must be re-joined
// from the upstream file before this compiles.
496 REGISTER_CUDNN_OPERATOR(Recurrent, RecurrentOp<float>);
497 OPERATOR_SCHEMA(Recurrent).NumInputs(4).NumOutputs(5).SetDoc(R
"DOC( 499 Recurrent wraps the CuDNN R5 RNN implementation. See the CuDNN R5 500 documentation for more information. 502 In general, the implementation takes an input (TxNxD) tensor, the 503 hidden state input (NxD), the cell input (NxD), and a weight tensor 504 (effectively an opaque blob, where the size and layout is dictated by 507 The outputs are the output (again, TxNxD), the final hidden/cell 508 states (NxD). These can be reset (at sequence boundaries across 509 minibatches) by multiplying by zero. 511 The CuDNN arguments (hidden_size, bidirectional, num_layers, rnn_mode, 512 input_mode) are passed directly through to CuDNN. 515 REGISTER_CUDNN_OPERATOR(RecurrentGradient, RecurrentGradientOp<float>); 516 OPERATOR_SCHEMA(RecurrentGradient) 519 .AllowInplace({{4, 5}}); 521 REGISTER_CUDNN_OPERATOR( 523 RecurrentParamAccessOp<float, SET_PARAM>); 524 OPERATOR_SCHEMA(RecurrentParamSet) 527 .EnforceInplace({{1, 0}}) 528 .SetDoc("Set individual parameters of a recurrent net.")
529 .Arg(
"param_type", R
"DOC(Type of param to be set: 530 "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w" 531 "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b" 533 .Arg("input_type",
"'recurrent' or 'input'")
534 .Arg(
"layer",
"layer index (starting from 0)")
535 .Input(0,
"input", R
"DOC(Input blob. Needed for inferring the shapes. 536 A dummy tensor matching the input shape is ok.)DOC") 537 .Input(1, "all_params",
"Blob holding all the parameters")
538 .Input(2,
"param",
"Values for the specified parameter")
542 "Blob holding all the parameters (same as input(1))");
544 REGISTER_CUDNN_OPERATOR(
546 RecurrentParamAccessOp<float, GET_PARAM>);
547 OPERATOR_SCHEMA(RecurrentParamGet)
550 .SetDoc(
"Retrieve individual parameters of a recurrent net op.")
551 .Arg(
"param_type", R
"DOC(Type of param to be set: 552 "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w" 553 "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b" 555 .Arg("input_type",
"'recurrent' or 'input'")
556 .Arg(
"layer",
"layer index (starting from 0)")
557 .Input(0,
"input", R
"DOC(Input blob. Needed for inferring the shapes. 558 A dummy tensor matching the input shape is ok.)DOC") 559 .Input(1, "all_params",
"Blob holding all the parameters")
560 .Output(0,
"param",
"Blob holding the requested values");
563 using GradientMakerBase::GradientMakerBase;
564 vector<OperatorDef> GetGradientDefs()
override {
565 return SingleGradientDef(
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...