Caffe2 - C++ API
A deep learning, cross-platform ML framework
listwise_l2r_op.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #pragma once
4 
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class LambdaRankNdcgOp final : public Operator<Context> {
14  public:
15  LambdaRankNdcgOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws) {}
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18  bool RunOnDevice() override;
19 
20  private:
21  INPUT_TAGS(PRED, REL);
22  OUTPUT_TAGS(LOSS, DPRED);
23 
24  void ResizeInvLogITensor(int);
25  void ComputeDiscounts(int*, int);
26  Tensor<Context> gain_;
27  Tensor<Context> discount_;
28  Tensor<Context> rank_idx_;
29  Tensor<Context> ideal_idx_;
30  Tensor<Context> lambda_;
31  Tensor<Context> inv_log_i_;
32 };
33 
34 template <typename T, class Context>
35 class LambdaRankNdcgGradientOp final : public Operator<Context> {
36  public:
37  USE_SIMPLE_CTOR_DTOR(LambdaRankNdcgGradientOp);
38  USE_OPERATOR_CONTEXT_FUNCTIONS;
39  bool RunOnDevice() override;
40 
41  private:
42  INPUT_TAGS(Y, DY_CACHE, DLOSS);
43  OUTPUT_TAGS(DY);
44 };
45 
46 } // namespace caffe2
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.