Caffe2 - C++ API
A deep learning, cross platform ML framework
rank_loss_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 // Supports multiple batches of sessions.
// Operator declaration for the pairwise loss, parameterized on the element
// type T and the device Context. The class is final and derives from
// Operator<Context>; only the interface is declared here.
// NOTE(review): presumably the forward pass of a pairwise ranking loss whose
// per-session grouping comes from the LENGTHS input -- confirm against the
// RunOnDevice() definition.
27 template <typename T, class Context>
28 class PairWiseLossOp final : public Operator<Context> {
29  public:
  // Constructor/destructor boilerplate supplied by a project macro.
30  USE_SIMPLE_CTOR_DTOR(PairWiseLossOp);
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Declared only; no definition is visible in this header.
32  bool RunOnDevice() override;
33 
34  private:
  // Symbolic names for the operator's inputs: XVALUE, LABEL, LENGTHS.
  // NOTE(review): presumably these map to input indices 0, 1, 2 in the order
  // listed -- confirm against the INPUT_TAGS macro definition.
35  INPUT_TAGS(XVALUE, LABEL, LENGTHS);
  // Symbolic name for the single output: YVALUE.
36  OUTPUT_TAGS(YVALUE);
37 };
38 
// Operator declaration for the gradient counterpart of the pairwise loss,
// parameterized on the element type T and the device Context. The class is
// final and derives from Operator<Context>; only the interface is declared
// here.
// NOTE(review): presumably computes DXVALUE (gradient w.r.t. XVALUE) from the
// upstream gradient DYVALUE -- confirm against the RunOnDevice() definition.
39 template <typename T, class Context>
40 class PairWiseLossGradientOp final : public Operator<Context> {
41  public:
  // Constructor/destructor boilerplate supplied by a project macro.
42  USE_SIMPLE_CTOR_DTOR(PairWiseLossGradientOp);
43  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Declared only; no definition is visible in this header.
44  bool RunOnDevice() override;
45 
46  private:
  // Symbolic names for the operator's inputs: XVALUE, LABEL, DYVALUE, LENGTHS.
  // Note that LENGTHS is the last input here, whereas it directly follows
  // LABEL in PairWiseLossOp's input list.
  // NOTE(review): presumably these map to input indices 0..3 in the order
  // listed -- confirm against the INPUT_TAGS macro definition.
47  INPUT_TAGS(XVALUE, LABEL, DYVALUE, LENGTHS);
  // Symbolic name for the single output: DXVALUE.
48  OUTPUT_TAGS(DXVALUE);
49 };
50 
51 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.