Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_op_gpu.cc
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/fully_connected_op.h"

namespace caffe2 {

namespace {

template <class FullyConnectedOp>
bool RunFullyConnectedOpOnCUDADevice(
    const bool float16_compute,
    FullyConnectedOp* op) {
  if (op->Input(0).template IsType<float>()) {
    return op->template DoRunWithType<
        float, // X
        float, // W
        float, // B
        float, // Y
        float>(); // Math
  } else if (op->Input(0).template IsType<at::Half>()) {
    if (float16_compute) {
      const cudaDeviceProp& prop = GetDeviceProperty(0);
      if (prop.major >= kFp16CUDADevicePropMajor) {
        return op->template DoRunWithType<
            at::Half, // X
            at::Half, // W
            at::Half, // B
            at::Half, // Y
            at::Half>(); // Math
      } else {
        LOG(INFO) << "CUDA Device does not support FP16 computation, "
                     "falling back to FP32.";
        return op->template DoRunWithType<
            at::Half, // X
            at::Half, // W
            at::Half, // B
            at::Half, // Y
            float>(); // Math
      }
    } else {
      return op->template DoRunWithType<
          at::Half, // X
          at::Half, // W
          at::Half, // B
          at::Half, // Y
          float>(); // Math
    }
  } else {
    CAFFE_THROW("Unsupported type");
  }
  return false;
}

template <class FullyConnectedGradientOp>
bool RunFullyConnectedGradientOpOnCUDADevice(
    const bool float16_compute,
    FullyConnectedGradientOp* op) {
  if (op->Input(0).template IsType<float>()) {
    return op->template DoRunWithType<
        float, // X
        float, // W
        float, // dY
        float, // B
        float, // dX
        float, // dW
        float, // dB
        float>(); // Math
  } else if (op->Input(0).template IsType<at::Half>()) {
    if (float16_compute) {
      const cudaDeviceProp& prop = GetDeviceProperty(0);
      if (prop.major >= kFp16CUDADevicePropMajor) {
        return op->template DoRunWithType<
            at::Half, // X
            at::Half, // W
            at::Half, // dY
            at::Half, // B
            at::Half, // dX
            at::Half, // dW
            at::Half, // dB
            at::Half>(); // Math
      } else {
        LOG(INFO) << "CUDA Device does not support FP16 computation, "
                     "falling back to FP32.";
        return op->template DoRunWithType<
            at::Half, // X
            at::Half, // W
            at::Half, // dY
            at::Half, // B
            at::Half, // dX
            at::Half, // dW
            at::Half, // dB
            float>(); // Math
      }
    } else {
      return op->template DoRunWithType<
          at::Half, // X
          at::Half, // W
          at::Half, // dY
          at::Half, // B
          at::Half, // dX
          at::Half, // dW
          at::Half, // dB
          float>(); // Math
    }
  } else {
    CAFFE_THROW("Unsupported type");
  }
  return false;
}

} // namespace

// RunFullyConnectedOpOnCUDADevice takes a pointer to the current op, and
// DoRunWithType dispatches to the typed implementation that matches the input.
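// For example, a float input resolves to DoRunWithType<float, float, float,
// float, float>(), while an at::Half input with float16_compute enabled on a
// capable device resolves to the all-at::Half instantiation (see the helpers
// above).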
template <>
bool FullyConnectedOp<CUDAContext>::RunOnDevice() {
  return RunFullyConnectedOpOnCUDADevice(float16_compute_, this);
}

template <>
bool FullyConnectedOp<
    CUDAContext,
    DefaultEngine,
    false /* don't transpose weight */>::RunOnDevice() {
  return RunFullyConnectedOpOnCUDADevice(float16_compute_, this);
}

template <>
bool FullyConnectedGradientOp<CUDAContext>::RunOnDevice() {
  return RunFullyConnectedGradientOpOnCUDADevice(float16_compute_, this);
}

template <>
bool FullyConnectedGradientOp<
    CUDAContext,
    DefaultEngine,
    false /* don't transpose weight */>::RunOnDevice() {
  return RunFullyConnectedGradientOpOnCUDADevice(float16_compute_, this);
}

#if CUDA_VERSION >= 9000

// These specializations must be defined; otherwise the TensorCore FC ops would
// end up calling the default FC implementation, which does not have fp16
// support.

template <>
bool FullyConnectedOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
  return RunFullyConnectedOpOnCUDADevice(false /* float16_compute */, this);
}

template <>
bool FullyConnectedOp<
    CUDAContext,
    TensorCoreEngine,
    false /* don't transpose weight */>::RunOnDevice() {
  return RunFullyConnectedOpOnCUDADevice(false /* float16_compute */, this);
}

template <>
bool FullyConnectedGradientOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
  return RunFullyConnectedGradientOpOnCUDADevice(
      false /* float16_compute */, this);
}

template <>
bool FullyConnectedGradientOp<
    CUDAContext,
    TensorCoreEngine,
    false /* don't transpose weight */>::RunOnDevice() {
  return RunFullyConnectedGradientOpOnCUDADevice(
      false /* float16_compute */, this);
}

#endif

REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<CUDAContext>);

REGISTER_CUDA_OPERATOR(
    FCTransposed,
    FullyConnectedOp<
        CUDAContext,
        DefaultEngine,
        false /* don't transpose weight */>);
REGISTER_CUDA_OPERATOR(
    FCTransposedGradient,
    FullyConnectedGradientOp<
        CUDAContext,
        DefaultEngine,
        false /* don't transpose weight */>);

#if CUDA_VERSION >= 9000
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
    FC,
    TENSORCORE,
    FullyConnectedOp<CUDAContext, TensorCoreEngine>);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
    FCGradient,
    TENSORCORE,
    FullyConnectedGradientOp<CUDAContext, TensorCoreEngine>);

REGISTER_CUDA_OPERATOR_WITH_ENGINE(
    FCTransposed,
    TENSORCORE,
    FullyConnectedOp<
        CUDAContext,
        TensorCoreEngine,
        false /* don't transpose weight */>);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
    FCTransposedGradient,
    TENSORCORE,
    FullyConnectedGradientOp<
        CUDAContext,
        TensorCoreEngine,
        false /* don't transpose weight */>);
#endif

} // namespace caffe2
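As a rough illustration of how the operators registered above might be exercised, the sketch below builds an FC OperatorDef on the CUDA device and runs it through a Workspace. This is a minimal sketch, not part of this file: the helper function, the blob names "X", "W", "b", "Y", and the commented-out TENSORCORE engine selection are assumptions for illustration.

#include <memory>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

// Hypothetical helper: assumes blobs "X", "W", and "b" already exist in *ws
// as CUDA tensors of compatible shapes.
void RunFullyConnectedOnGPU(caffe2::Workspace* ws) {
  caffe2::OperatorDef def;
  def.set_type("FC"); // resolved through REGISTER_CUDA_OPERATOR(FC, ...)
  def.add_input("X");
  def.add_input("W");
  def.add_input("b");
  def.add_output("Y");
  def.mutable_device_option()->set_device_type(caffe2::PROTO_CUDA);
  // def.set_engine("TENSORCORE"); // would select the TensorCoreEngine variant
  std::unique_ptr<caffe2::OperatorBase> op = caffe2::CreateOperator(def, ws);
  CAFFE_ENFORCE(op);
  CAFFE_ENFORCE(op->Run());
}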