Caffe2 - C++ API
A deep learning, cross platform ML framework
string_ops.cc
1 #include "caffe2/operators/string_ops.h"
2 #include "caffe2/core/operator.h"
3 
4 namespace caffe2 {
5 
6 template <>
7 template <typename T>
8 bool StringJoinOp<CPUContext>::DoRunWithType() {
9  const auto& input = Input(0);
10 
11  CAFFE_ENFORCE_GT(input.numel(), 0);
12  CAFFE_ENFORCE_LE(input.dim(), 2, "Only 1-D and 2-D tensors are supported");
13 
14  const auto* inputData = input.data<T>();
15  int rowSize = (input.dim() == 2) ? input.size(1) : 1;
16  if (this->axis_ == 0) {
17  auto* output = Output(0, {input.size(0)}, at::dtype<std::string>());
18  auto* outputData = output->template mutable_data<std::string>();
19 
20  int offset = 0;
21  for (int i = 0; i < input.size(0); ++i) {
22  std::stringstream stream;
23  std::copy(
24  inputData + offset,
25  inputData + offset + rowSize,
26  std::ostream_iterator<T>(stream, delimiter_.c_str()));
27  outputData[i] = stream.str();
28  offset += rowSize;
29  }
30  } else if (this->axis_ == 1) {
31  auto* output = Output(0, {input.size(1)}, at::dtype<std::string>());
32  auto* outputData = output->template mutable_data<std::string>();
33 
34  for (int j = 0; j < input.size(1); ++j) {
35  std::stringstream stream;
36  for (int i = 0; i < input.size(0); ++i) {
37  stream << inputData[i * rowSize + j] << delimiter_;
38  }
39  outputData[j] = stream.str();
40  }
41  } else {
42  CAFFE_ENFORCE(false, "Not supported");
43  }
44 
45  return true;
46 }
47 
48 namespace {
49 
50 struct StartsWith {
51  explicit StartsWith(OperatorBase& op)
52  : prefix_(op.GetSingleArgument<std::string>("prefix", "")) {}
53  bool operator()(const std::string& str) {
54  return std::mismatch(prefix_.begin(), prefix_.end(), str.begin()).first ==
55  prefix_.end();
56  }
57 
58  private:
59  std::string prefix_;
60 };
61 
62 struct EndsWith {
63  explicit EndsWith(OperatorBase& op)
64  : suffix_(op.GetSingleArgument<std::string>("suffix", "")) {}
65  bool operator()(const std::string& str) {
66  return std::mismatch(suffix_.rbegin(), suffix_.rend(), str.rbegin())
67  .first == suffix_.rend();
68  }
69 
70  private:
71  std::string suffix_;
72 };
73 
74 struct Prefix {
75  explicit Prefix(OperatorBase& op)
76  : length_(op.GetSingleArgument<int>("length", 3)) {}
77  std::string operator()(const std::string& str) {
78  return std::string(str.begin(), std::min(str.end(), str.begin() + length_));
79  }
80 
81  private:
82  int length_;
83 };
84 
85 struct Suffix {
86  explicit Suffix(OperatorBase& op)
87  : length_(op.GetSingleArgument<int>("length", 3)) {}
88  std::string operator()(const std::string& str) {
89  return std::string(std::max(str.begin(), str.end() - length_), str.end());
90  }
91 
92  private:
93  int length_;
94 };
95 
96 template <typename ScalarFunctor, typename TypeMap = FixedType<std::string>>
97 using StringElementwiseOp = UnaryElementwiseWithArgsOp<
98  TensorTypes<std::string>,
99  CPUContext,
100  ForEach<ScalarFunctor>,
101  TypeMap>;
102 
// Byte-wise prefix/suffix extraction: string tensor in, string tensor out.
REGISTER_CPU_OPERATOR(StringPrefix, StringElementwiseOp<Prefix>);
REGISTER_CPU_OPERATOR(StringSuffix, StringElementwiseOp<Suffix>);
// Predicates: string tensor in, bool tensor out. FixedType<bool> overrides the
// default std::string output type of StringElementwiseOp.
REGISTER_CPU_OPERATOR(
    StringStartsWith,
    StringElementwiseOp<StartsWith, FixedType<bool>>);
REGISTER_CPU_OPERATOR(
    StringEndsWith,
    StringElementwiseOp<EndsWith, FixedType<bool>>);
// StringJoin is not element-wise; it uses its own operator class (its CPU
// kernel is StringJoinOp<CPUContext>::DoRunWithType in this file).
REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>);
112 
113 OPERATOR_SCHEMA(StringPrefix)
114  .NumInputs(1)
115  .NumOutputs(1)
116  .SetDoc(R"DOC(
117 Computes the element-wise string prefix of the string tensor.
118 Input strings that are shorter than prefix length will be returned unchanged.
119 NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior
120 and potentially invalid strings for variable-length encodings such as utf-8.
121 )DOC")
122  .Arg("length", "Maximum size of the prefix, in bytes.")
123  .Input(0, "strings", "Tensor of std::string.")
124  .Output(
125  0,
126  "prefixes",
127  "Tensor of std::string containing prefixes for each input.");
128 
129 OPERATOR_SCHEMA(StringSuffix)
130  .NumInputs(1)
131  .NumOutputs(1)
132  .SetDoc(R"DOC(
133 Computes the element-wise string suffix of the string tensor.
134 Input strings that are shorter than suffix length will be returned unchanged.
135 NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior
136 and potentially invalid strings for variable-length encodings such as utf-8.
137 )DOC")
138  .Input(0, "strings", "Tensor of std::string.")
139  .Output(
140  0,
141  "suffixes",
142  "Tensor of std::string containing suffixes for each output.")
143  .Arg("length", "Maximum size of the suffix, in bytes.");
144 
145 OPERATOR_SCHEMA(StringStartsWith)
146  .NumInputs(1)
147  .NumOutputs(1)
148  .SetDoc(R"DOC(
149 Performs the starts-with check on each string in the input tensor.
150 Returns tensor of boolean of the same dimension of input.
151 )DOC")
152  .Arg("prefix", "The prefix to check input strings against.")
153  .Input(0, "strings", "Tensor of std::string.")
154  .Output(0, "bools", "Tensor of bools of same shape as input.");
155 
156 OPERATOR_SCHEMA(StringEndsWith)
157  .NumInputs(1)
158  .NumOutputs(1)
159  .SetDoc(R"DOC(
160 Performs the ends-with check on each string in the input tensor.
161 Returns tensor of boolean of the same dimension of input.
162 )DOC")
163  .Arg("suffix", "The suffix to check input strings against.")
164  .Input(0, "strings", "Tensor of std::string.")
165  .Output(0, "bools", "Tensor of bools of same shape as input.");
166 
167 OPERATOR_SCHEMA(StringJoin)
168  .NumInputs(1)
169  .NumOutputs(1)
170  .SetDoc(R"DOC(
171 Takes a 1-D or a 2-D tensor as input and joins elements in each row with the
172 provided delimiter. Output is a 1-D tensor of size equal to the first dimension
173 of the input. Each element in the output tensor is a string of concatenated
174 elements corresponding to each row in the input tensor. For 1-D input, each
175 element is treated as a row.
176 )DOC")
177  .Arg("delimiter", "Delimiter for join (Default: \",\").")
178  .Arg("axis", "Axis for the join (either 0 or 1)")
179  .Input(0, "input", "1-D or 2-D tensor")
180  .Output(
181  0,
182  "strings",
183  "1-D tensor of strings created by joining row elements from the "
184  "input tensor.");
185 
// Declare that none of the string operators in this file have a gradient:
// their outputs are strings or bools, which are not differentiable.
SHOULD_NOT_DO_GRADIENT(StringPrefix);
SHOULD_NOT_DO_GRADIENT(StringSuffix);
SHOULD_NOT_DO_GRADIENT(StringStartsWith);
SHOULD_NOT_DO_GRADIENT(StringEndsWith);
SHOULD_NOT_DO_GRADIENT(StringJoin);
} // namespace
192 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13