Caffe2 - C++ API
A deep learning, cross-platform ML framework
minmax_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
18 #define CAFFE2_OPERATORS_MINMAX_OPS_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/core/types.h"
25 #include "caffe2/utils/math.h"
26 
27 namespace caffe2 {
28 
29 template <typename T, class Context>
30 class MaxMinOpBase : public Operator<Context> {
31  public:
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33  USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
34 
35  bool RunOnDevice() override {
36  auto& input0 = Input(0);
37  auto* output = Output(0);
38 
39  output->ResizeLike(input0);
40  output->CopyFrom(input0, &context_);
41 
42  if (InputSize() == 1) {
43  return true;
44  }
45 
46  // Dimension checking
47  for (int i = 1; i < InputSize(); ++i) {
48  CAFFE_ENFORCE_EQ(
49  output->dims(),
50  Input(i).dims(),
51  "Description: Input #",
52  i,
53  ", input dimension:",
54  Input(i).dims(),
55  " should match output dimension: ",
56  output->dims());
57  }
58 
59  return this->Compute();
60  }
61 
62  virtual bool Compute() = 0;
63 };
64 
65 template <typename T, class Context>
66 class MaxOp : public MaxMinOpBase<T, Context> {
67  public:
68  USE_OPERATOR_CONTEXT_FUNCTIONS;
69  MaxOp(const OperatorDef& operator_def, Workspace* ws)
70  : MaxMinOpBase<T, Context>(operator_def, ws) {}
71  virtual ~MaxOp() noexcept {}
72  bool Compute() override;
73 };
74 
75 template <typename T, class Context>
76 class SelectGradientOpBase : public Operator<Context> {
77  public:
78  USE_OPERATOR_CONTEXT_FUNCTIONS;
79  USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
80 
81  bool RunOnDevice() override;
82 };
83 
84 template <typename T, class Context>
85 class MaxGradientOp : public SelectGradientOpBase<T, Context> {
86  public:
87  MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
88  : SelectGradientOpBase<T, Context>(operator_def, ws) {}
89  virtual ~MaxGradientOp() noexcept {}
90 };
91 
92 template <typename T, class Context>
93 class MinOp : public MaxMinOpBase<T, Context> {
94  public:
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  MinOp(const OperatorDef& operator_def, Workspace* ws)
97  : MaxMinOpBase<T, Context>(operator_def, ws) {}
98  virtual ~MinOp() noexcept {}
99  bool Compute() override;
100 };
101 
102 template <typename T, class Context>
103 class MinGradientOp : public SelectGradientOpBase<T, Context> {
104  public:
105  MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
106  : SelectGradientOpBase<T, Context>(operator_def, ws) {}
107  virtual ~MinGradientOp() noexcept {}
108 };
109 
110 } // namespace caffe2
111 
112 #endif // CAFFE2_OPERATORS_MINMAX_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.