Caffe2 - C++ API
A deep learning, cross-platform ML framework
string_ops.cc
1 
17 #include "caffe2/operators/string_ops.h"
18 #include "caffe2/core/operator.h"
19 
20 namespace caffe2 {
21 
22 template <>
23 template <typename T>
24 bool StringJoinOp<CPUContext>::DoRunWithType() {
25  const auto& input = Input(0);
26  auto* output = Output(0);
27  CAFFE_ENFORCE_GT(input.size(), 0);
28  CAFFE_ENFORCE_LE(input.ndim(), 2, "Only 1-D and 2-D tensors are supported");
29 
30  const auto* inputData = input.data<T>();
31  int rowSize = (input.ndim() == 2) ? input.dim(1) : 1;
32  if (this->axis_ == 0) {
33  output->Resize(input.dim(0));
34  auto* outputData = output->mutable_data<std::string>();
35 
36  int offset = 0;
37  for (int i = 0; i < input.dim(0); ++i) {
38  std::stringstream stream;
39  std::copy(
40  inputData + offset,
41  inputData + offset + rowSize,
42  std::ostream_iterator<T>(stream, delimiter_.c_str()));
43  outputData[i] = stream.str();
44  offset += rowSize;
45  }
46  } else if (this->axis_ == 1) {
47  output->Resize(input.dim(1));
48  auto* outputData = output->mutable_data<std::string>();
49 
50  for (int j = 0; j < input.dim(1); ++j) {
51  std::stringstream stream;
52  for (int i = 0; i < input.dim(0); ++i) {
53  stream << inputData[i * rowSize + j] << delimiter_;
54  }
55  outputData[j] = stream.str();
56  }
57  } else {
58  CAFFE_ENFORCE(false, "Not supported");
59  }
60 
61  return true;
62 }
63 
64 namespace {
65 
66 struct StartsWith {
67  explicit StartsWith(OperatorBase& op)
68  : prefix_(op.GetSingleArgument<std::string>("prefix", "")) {}
69  bool operator()(const std::string& str) {
70  return std::mismatch(prefix_.begin(), prefix_.end(), str.begin()).first ==
71  prefix_.end();
72  }
73 
74  private:
75  std::string prefix_;
76 };
77 
78 struct EndsWith {
79  explicit EndsWith(OperatorBase& op)
80  : suffix_(op.GetSingleArgument<std::string>("suffix", "")) {}
81  bool operator()(const std::string& str) {
82  return std::mismatch(suffix_.rbegin(), suffix_.rend(), str.rbegin())
83  .first == suffix_.rend();
84  }
85 
86  private:
87  std::string suffix_;
88 };
89 
90 struct Prefix {
91  explicit Prefix(OperatorBase& op)
92  : length_(op.GetSingleArgument<int>("length", 3)) {}
93  std::string operator()(const std::string& str) {
94  return std::string(str.begin(), std::min(str.end(), str.begin() + length_));
95  }
96 
97  private:
98  int length_;
99 };
100 
101 struct Suffix {
102  explicit Suffix(OperatorBase& op)
103  : length_(op.GetSingleArgument<int>("length", 3)) {}
104  std::string operator()(const std::string& str) {
105  return std::string(std::max(str.begin(), str.end() - length_), str.end());
106  }
107 
108  private:
109  int length_;
110 };
111 
112 template <typename ScalarFunctor, typename TypeMap = FixedType<std::string>>
113 using StringElementwiseOp = UnaryElementwiseWithArgsOp<
114  TensorTypes<std::string>,
115  CPUContext,
116  ForEach<ScalarFunctor>,
117  TypeMap>;
118 
119 REGISTER_CPU_OPERATOR(StringPrefix, StringElementwiseOp<Prefix>);
120 REGISTER_CPU_OPERATOR(StringSuffix, StringElementwiseOp<Suffix>);
121 REGISTER_CPU_OPERATOR(
122  StringStartsWith,
123  StringElementwiseOp<StartsWith, FixedType<bool>>);
124 REGISTER_CPU_OPERATOR(
125  StringEndsWith,
126  StringElementwiseOp<EndsWith, FixedType<bool>>);
127 REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>);
128 
129 OPERATOR_SCHEMA(StringPrefix)
130  .NumInputs(1)
131  .NumOutputs(1)
132  .SetDoc(R"DOC(
133 Computes the element-wise string prefix of the string tensor.
134 Input strings that are shorter than prefix length will be returned unchanged.
135 NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior
136 and potentially invalid strings for variable-length encodings such as utf-8.
137 )DOC")
138  .Arg("length", "Maximum size of the prefix, in bytes.")
139  .Input(0, "strings", "Tensor of std::string.")
140  .Output(
141  0,
142  "prefixes",
143  "Tensor of std::string containing prefixes for each input.");
144 
145 OPERATOR_SCHEMA(StringSuffix)
146  .NumInputs(1)
147  .NumOutputs(1)
148  .SetDoc(R"DOC(
149 Computes the element-wise string suffix of the string tensor.
150 Input strings that are shorter than suffix length will be returned unchanged.
151 NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior
152 and potentially invalid strings for variable-length encodings such as utf-8.
153 )DOC")
154  .Input(0, "strings", "Tensor of std::string.")
155  .Output(
156  0,
157  "suffixes",
158  "Tensor of std::string containing suffixes for each output.")
159  .Arg("length", "Maximum size of the suffix, in bytes.");
160 
161 OPERATOR_SCHEMA(StringStartsWith)
162  .NumInputs(1)
163  .NumOutputs(1)
164  .SetDoc(R"DOC(
165 Performs the starts-with check on each string in the input tensor.
166 Returns tensor of boolean of the same dimension of input.
167 )DOC")
168  .Arg("prefix", "The prefix to check input strings against.")
169  .Input(0, "strings", "Tensor of std::string.")
170  .Output(0, "bools", "Tensor of bools of same shape as input.");
171 
172 OPERATOR_SCHEMA(StringEndsWith)
173  .NumInputs(1)
174  .NumOutputs(1)
175  .SetDoc(R"DOC(
176 Performs the ends-with check on each string in the input tensor.
177 Returns tensor of boolean of the same dimension of input.
178 )DOC")
179  .Arg("suffix", "The suffix to check input strings against.")
180  .Input(0, "strings", "Tensor of std::string.")
181  .Output(0, "bools", "Tensor of bools of same shape as input.");
182 
183 OPERATOR_SCHEMA(StringJoin)
184  .NumInputs(1)
185  .NumOutputs(1)
186  .SetDoc(R"DOC(
187 Takes a 1-D or a 2-D tensor as input and joins elements in each row with the
188 provided delimiter. Output is a 1-D tensor of size equal to the first dimension
189 of the input. Each element in the output tensor is a string of concatenated
190 elements corresponding to each row in the input tensor. For 1-D input, each
191 element is treated as a row.
192 )DOC")
193  .Arg("delimiter", "Delimiter for join (Default: \",\").")
194  .Arg("axis", "Axis for the join (either 0 or 1)")
195  .Input(0, "input", "1-D or 2-D tensor")
196  .Output(
197  0,
198  "strings",
199  "1-D tensor of strings created by joining row elements from the "
200  "input tensor.");
201 
202 SHOULD_NOT_DO_GRADIENT(StringPrefix);
203 SHOULD_NOT_DO_GRADIENT(StringSuffix);
204 SHOULD_NOT_DO_GRADIENT(StringStartsWith);
205 SHOULD_NOT_DO_GRADIENT(StringEndsWith);
206 SHOULD_NOT_DO_GRADIENT(StringJoin);
207 }
208 } // namespace caffe2
Definition: types.h:88
Copyright (c) 2016-present, Facebook, Inc.