#include "softmax_with_loss_op.h"
#include "softmax_shared.h"

REGISTER_CPU_OPERATOR(SoftmaxWithLoss, SoftmaxWithLossOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    SoftmaxWithLossGradient,
    SoftmaxWithLossGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(SoftmaxWithLoss)
    .NumInputs(2, 3)
    .NumOutputs(2)
    .TensorInferenceFunction(
        [](const OperatorDef& def, const vector<TensorShape>& in) {
          ArgumentHelper helper(def);
          auto axis = helper.GetSingleArgument<int32_t>("axis", 1);

          vector<TensorShape> out(2);

          auto logits = in[0]; // Tensor with shape [batch_size, num_classes]
          const auto canonical_axis =
              canonical_axis_index_(axis, logits.dims().size());
          const int batch_size =
              size_to_dim_(canonical_axis, GetDimsVector(logits));
          const int num_classes =
              size_from_dim_(canonical_axis, GetDimsVector(logits));

          out[0].set_data_type(logits.data_type());
          out[0].add_dims(batch_size);
          out[0].add_dims(num_classes);

          return out;
        })
    .SetDoc(R"DOC(
Combined Softmax and Cross-Entropy loss operator. The operator first computes
the softmax normalized values for each layer in the batch of the given input,
then computes the cross-entropy loss. This operator is numerically more stable
than separate `Softmax` and `CrossEntropy` ops. The inputs are a 2-D tensor
`logits` of size (batch_size x input_feature_dimensions), which represents the
unscaled log probabilities, and a 1-dimensional integer `labels` tensor for
ground truth. An optional third input blob (`weight_tensor`) can be used to
weight the samples for the loss, which is useful if the training set is
unbalanced. This operator outputs a `softmax` tensor which contains the
probability for each label for each example (same shape as the `logits`
input), and a scalar `loss` value, which is the averaged cross-entropy loss
between the softmax probabilities and the ground truth values. Use parameter
`label_prob`=1 to enable inputting labels as a probability distribution.

Softmax cross-entropy loss function:

$$loss(x, class) = -\log{\biggl(\frac{\exp(x[class])}{\sum_{j} \exp(x[j])}\biggr)} = -x[class] + \log{\biggl(\sum_{j} \exp(x[j])\biggr)}$$

or if the `weight_tensor` has been passed:

$$loss(x, class) = weight[class]\biggl(-x[class] + \log{\biggl(\sum_{j} \exp(x[j])\biggr)}\biggr)$$

The `logits` input does not need to explicitly be a 2D vector; rather, it will
be coerced into one. For an arbitrary n-dimensional tensor `X` in
$[a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}]$, where k is the `axis` provided,
`X` will be coerced into a 2-dimensional tensor with dimensions
$[(a_0 * ... * a_{k-1}), (a_k * ... * a_{n-1})]$. For the default case where
`axis`=1, the `X` tensor will be coerced into a 2D tensor of dimensions
$[a_0, (a_1 * ... * a_{n-1})]$, where $a_0$ is often the batch size. In this
situation, we must have $a_0 = N$ and $a_1 * ... * a_{n-1} = D$. Each of these
dimensions must be matched correctly, or else the operator will throw errors.
(A NumPy reference sketch of this computation is included after the operator
schema below.)

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/softmax_with_loss_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SoftmaxWithLoss",
    ["logits", "labels"],
    ["softmax", "avgloss"]
)

workspace.FeedBlob("logits", np.random.randn(1, 5).astype(np.float32))
workspace.FeedBlob("labels", np.asarray([4]).astype(np.int32))
print("logits:", workspace.FetchBlob("logits"))
print("labels:", workspace.FetchBlob("labels"))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("softmax"))
print("avgloss:", workspace.FetchBlob("avgloss"))

```

**Result**

```

logits: [[-0.3429451  -0.80375195  0.23104447  1.4569176  -0.5268362 ]]
softmax: [[0.09721052 0.0613179  0.17258129 0.58800864 0.0808817 ]]

```

</details>

<details>

<summary> <b>Example 2</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SoftmaxWithLoss",
    ["logits", "labels"],
    ["softmax", "avgloss"],
)

workspace.FeedBlob("logits", np.asarray([[.1, .4, .7, 1.5, .2]]).astype(np.float32))
workspace.FeedBlob("labels", np.asarray([4]).astype(np.int32))
print("logits:", workspace.FetchBlob("logits"))
print("labels:", workspace.FetchBlob("labels"))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("softmax"))
print("avgloss:", workspace.FetchBlob("avgloss"))

```

**Result**

```

logits: [[0.1 0.4 0.7 1.5 0.2]]
softmax: [[0.10715417 0.144643   0.19524762 0.4345316  0.11842369]]

```

</details>

)DOC")
    .Arg(
        "label_prob",
        "*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
    .Arg(
        "axis",
        "*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
    .Arg(
        "scale",
        "*(type: float)* Average loss output scaling factor (must be >= 0).")
    .Arg(
        "order",
        "*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
    .Input(
        0,
        "logits",
        "*(type: Tensor`<float>`)* Input tensor.")
    .Input(
        1,
        "labels",
        "*(type: Tensor`<float>`)* Ground truth label tensor.")
    .Input(
        2,
        "weight_tensor",
        "*(type: Tensor`<float>`)* [OPTIONAL] Blob used to weight the samples for the loss.")
    .Output(
        0,
        "softmax",
        "*(type: Tensor`<float>`)* Softmax output tensor.")
    .Output(
        1,
        "loss",
        "*(type: float)* Averaged cross-entropy loss output.");
OPERATOR_SCHEMA(SoftmaxWithLossGradient).NumOutputs(1);
#define DONT_CARE (-1)

template <>
bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // Logits
  auto& T = Input(1); // Labels / targets

  const auto canonical_axis = X.canonical_axis_index(axis_);
  int64_t N, D;
  N = X.size_to_dim(canonical_axis); // batch size
  D = X.size_from_dim(canonical_axis);
  auto* P =
      Output(0, X.sizes(), at::dtype<float>()); // Probabilities from softmax

  float* Pdata = P->template mutable_data<float>();
  const float* weights = (InputSize() > 2 ? Input(2).data<float>() : nullptr);
  if (label_prob_mode_) {
    CAFFE_ENFORCE_GE(T.dim(), 2);
    CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
    CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), D);
  } else {
    if (T.dim() == canonical_axis) {
      CAFFE_ENFORCE_EQ(T.numel(), N);
    } else {
      CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
      CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), 1);
    }
  }
  if (!sum_multiplier_.defined()) {
    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  } else if (sum_multiplier_.numel() != D) {
    sum_multiplier_.Resize(D);
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  }

  if (!losses_.defined()) {
    losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (losses_.numel() != N) {
    losses_.Resize(N);
  }

  if (!rowmax_.defined()) {
    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (rowmax_.numel() != N) {
    rowmax_.Resize(N);
  }

  // Row-wise softmax of X into P (log-softmax when labels are class indices).
  SoftmaxCPU(
      context_,
      N,
      D,
      X.data<float>(),
      Pdata,
      losses_.mutable_data<float>(),
      sum_multiplier_.data<float>(),
      !label_prob_mode_,
      rowmax_.mutable_data<float>());
  // Then compute the cross-entropy of each row against its label.
  float loss_sum = 0.0;
  float weight_sum = 0.0;
  if (!label_prob_mode_) {
    const int* label_data = T.data<int>();
    const float* Xdata = X.data<float>();

    for (int i = 0; i < N; ++i) {
      CAFFE_ENFORCE(
          label_data[i] < D && label_data[i] >= 0,
          "Label seems incorrect: label value larger than number of classes: ",
          label_data[i],
          " vs ",
          D);
      float weight = weights ? weights[i] : 1.0;
      float l = -Pdata[i * D + label_data[i]] * weight;
      loss_sum += l;
      weight_sum += weight;
    }
    math::Exp(N * D, Pdata, Pdata, &context_);
  } else {
    const float* label_data = T.data<float>();

    for (int i = 0; i < N; ++i) {
      float l = 0.0;
      float total_prob = 0.0;
      float weight = weights ? weights[i] : 1.0;
      for (int j = 0; j < D; ++j) {
        CAFFE_ENFORCE(
            label_data[i * D + j] >= 0,
            "Label prob seems incorrect: label prob value must be nonnegative:",
            " ",
            label_data[i * D + j]);
        l += -log(std::max(Pdata[i * D + j], 1e-20f)) * label_data[i * D + j] *
            weight;
        total_prob += label_data[i * D + j];
      }
      loss_sum += l;
      CAFFE_ENFORCE(
          std::abs(total_prob - 1.) < 1e-5f,
          "Label prob seems incorrect: label prob values do not sum to 1.0: ",
          total_prob,
          " vs 1.0 (+/- 1e-5)");
      weight_sum += weight;
    }
  }

  auto* avg_loss =
      Output(1, vector<int64_t>(), at::dtype<float>()); // Average loss

  float* avg_loss_data = avg_loss->template mutable_data<float>();
  if (weight_sum != 0.0) {
    avg_loss_data[0] = loss_sum * scale_ / weight_sum;
  } else {
    avg_loss_data[0] = 0.0;
  }
  return true;
}
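
The forward pass above accumulates each sample's weight into `weight_sum` and divides the scaled loss sum by it. A usage sketch for the optional `weight_tensor` input, in the same style as the examples in the schema doc (the blob values here are illustrative, not taken from the source):

```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SoftmaxWithLoss",
    ["logits", "labels", "weight_tensor"],
    ["softmax", "avgloss"]
)

# Two samples; the second one contributes twice as much to the averaged loss.
workspace.FeedBlob("logits", np.random.randn(2, 5).astype(np.float32))
workspace.FeedBlob("labels", np.asarray([4, 1]).astype(np.int32))
workspace.FeedBlob("weight_tensor", np.asarray([1.0, 2.0]).astype(np.float32))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("softmax"))
print("avgloss:", workspace.FetchBlob("avgloss"))
```
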
template <>
bool SoftmaxWithLossGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // Logits
  auto& T = Input(1); // Labels / targets
  auto& P = Input(InputSize() - 2); // Probabilities from softmax
  auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss

  // A fifth input means the optional weight blob was passed as Input(2).
  const float* weights = (InputSize() > 4 ? Input(2).data<float>() : nullptr);

  const auto canonical_axis = X.canonical_axis_index(axis_);
  int64_t N, D;
  N = X.size_to_dim(canonical_axis); // batch size
  D = X.size_from_dim(canonical_axis);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  if (label_prob_mode_) {
    CAFFE_ENFORCE_GE(T.dim(), 2);
    CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
    CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), D);
  } else {
    if (T.dim() == canonical_axis) {
      CAFFE_ENFORCE_EQ(T.numel(), N);
    } else {
      CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
      CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), 1);
    }
  }
  const float* Pdata = P.data<float>();
  float* dX_data = dX->template mutable_data<float>();

  // Start from the softmax probabilities: for every class except the correct
  // label, the gradient of the loss w.r.t. the logit is just the probability.
  context_.CopyFromCPU<float>(P.numel(), Pdata, dX_data);
  // Compute the gradient for the matching labels.
  float total_weight = 0.0f;
  if (!label_prob_mode_) {
    const int* label_data = T.data<int>();

    if (weights) {
      for (int i = 0; i < N; ++i) {
        int idx = i * D + label_data[i];
        float weight = weights[i];
        dX_data[idx] = Pdata[idx] - 1.0;
        for (int d = 0; d < D; d++) {
          int k = i * D + d;
          dX_data[k] *= weight;
        }

        total_weight += weight;
      }
    } else {
      for (int i = 0; i < N; ++i) {
        int idx = i * D + label_data[i];
        dX_data[idx] = Pdata[idx] - 1.0f;
      }
      total_weight = N;
    }
  } else {
    const float* label_data = T.data<float>();

    if (weights) {
      for (int i = 0; i < N; ++i) {
        float weight = weights[i];
        for (int j = 0; j < D; ++j) {
          int idx = i * D + j;
          dX_data[idx] = (Pdata[idx] - label_data[idx]) * weight;
        }
        total_weight += weight;
      }
    } else {
      for (int i = 0; i < N; ++i) {
        for (int j = 0; j < D; ++j) {
          int idx = i * D + j;
          dX_data[idx] = Pdata[idx] - label_data[idx];
        }
      }
      total_weight = N;
    }
  }

  // Scale the whole gradient by scale_ / total_weight * d_avg_loss.
  if (total_weight > 0) {
    math::Scale<float, float, CPUContext>(
        dX->numel(),
        scale_ / total_weight * d_avg_loss.data<float>()[0],
        dX->data<float>(),
        dX_data,
        &context_);
  }

  return true;
}
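
A NumPy sketch of what the gradient kernel above computes in the index-label case: the softmax output minus a one-hot encoding of the label, optionally weighted per sample, then scaled by `scale_ / total_weight * d_avg_loss`. This is an illustration under those assumptions, not the CPU kernel, and the helper name is invented:

```
import numpy as np

def softmax_with_loss_grad_ref(P, labels, d_avg_loss, weights=None, scale=1.0):
    # P: (N, D) softmax from the forward pass; labels: (N,) class indices.
    N, D = P.shape
    dX = P.astype(np.float32).copy()
    dX[np.arange(N), labels] -= 1.0  # dl_i/dx_ij = p_ij - [j == label_i]
    if weights is not None:
        dX *= np.asarray(weights, dtype=np.float32).reshape(N, 1)
        total_weight = float(np.sum(weights))
    else:
        total_weight = float(N)
    return dX * (scale / total_weight * float(d_avg_loss))
```

Comparing this against a finite-difference of the forward sketch shown earlier is a quick way to sanity-check the sign and scaling conventions.
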
class GetSoftmaxWithLossGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> blob_names{
        {I(0), I(1), O(0), GO(1)},
    };

    // Add the weight blob to the gradient op's inputs, if it was given.
    if (def_.input_size() == 3) {
      blob_names.emplace(blob_names.begin() + 2, I(2));
    }
    return SingleGradientDef(
        "SoftmaxWithLossGradient", "", blob_names, vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient);
 