Caffe2 - C++ API
A deep-learning, cross-platform ML framework
listwise_l2r_op.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #pragma once
4 
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class LambdaRankNdcgOp final : public Operator<Context> {
14  public:
15  template <class... Args>
16  explicit LambdaRankNdcgOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  use_ndcg_as_loss_(
19  this->template GetSingleArgument<bool>("use_ndcg_as_loss", false)),
20  use_exp_gain_(
21  this->template GetSingleArgument<bool>("use_exp_gain", true)) {}
22  USE_OPERATOR_CONTEXT_FUNCTIONS;
23  bool RunOnDevice() override;
24 
25  private:
26  INPUT_TAGS(PRED, REL, SESSION_LENS);
27  OUTPUT_TAGS(LOSS, DPRED);
28 
29  void ResizeInvLogITensor(int);
30  void ComputeDiscounts(int*, int);
31  float LambdaRankNdcgSession(
32  int start_index,
33  int end_index,
34  const Tensor& y,
35  const Tensor& r,
36  Tensor** dy);
37  bool use_ndcg_as_loss_;
38  bool use_exp_gain_;
39  Tensor gain_;
40  Tensor discount_;
41  Tensor rank_idx_;
42  Tensor ideal_idx_;
43  Tensor lambda_;
44  Tensor inv_log_i_;
45 };
46 
47 template <typename T, class Context>
48 class LambdaRankNdcgGradientOp final : public Operator<Context> {
49  public:
50  USE_SIMPLE_CTOR_DTOR(LambdaRankNdcgGradientOp);
51  USE_OPERATOR_CONTEXT_FUNCTIONS;
52  bool RunOnDevice() override;
53 
54  private:
55  INPUT_TAGS(Y, SESSION_LENS, DY_CACHE, DLOSS);
56  OUTPUT_TAGS(DY);
57 };
58 
59 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13