3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python import core, workspace, model_helper, utils, brew
10 from caffe2.proto
import caffe2_pb2
17 from datetime
import datetime
20 This script takes a text file as input and uses a recurrent neural network 21 to learn to predict the next character in a sequence. 25 log = logging.getLogger(
"char_rnn")
26 log.setLevel(logging.DEBUG)
def CreateNetOnce(net, created_names=set()):  # noqa: B006
    """Create ``net`` in the workspace unless its name was already seen.

    Args:
        net: Object exposing ``Name()``; it is passed unchanged to
            ``workspace.CreateNet``.
        created_names: Set of net names that have already been created.
            The mutable default is intentional: the same set is shared by
            every call that omits this argument, so it serves as the record
            of nets created so far and repeated calls for the same net name
            do nothing.
    """
    # The net's name is the key for the "already created" record.
    name = net.Name()
    if name not in created_names:
        # Record the name before creating, so the net is only created once.
        created_names.add(name)
        workspace.CreateNet(net)
39 def __init__(self, args):
45 with open(args.train_data)
as f:
53 print(
"Input has {} characters. Total input size: {}".format(
56 def CreateModel(self):
57 log.debug(
"Start training")
60 input_blob, seq_lengths, hidden_init, cell_init, target = \
61 model.net.AddExternalInputs(
69 hidden_output_all, self.hidden_output, _, self.
cell_state = LSTM(
70 model, input_blob, seq_lengths, (hidden_init, cell_init),
83 softmax = model.net.Softmax(output,
'softmax', axis=2)
85 softmax_reshaped, _ = model.net.Reshape(
86 softmax, [
'softmax_reshaped',
'_'], shape=[-1, self.
D])
92 xent = model.net.LabelCrossEntropy([softmax_reshaped, target],
'xent')
95 loss = model.net.AveragedLoss(xent,
'loss')
96 model.AddGradientOperators([loss])
112 self.prepare_state.Copy(self.hidden_output, hidden_init)
113 self.prepare_state.Copy(self.
cell_state, cell_init)
115 def _idx_at_pos(self, pos):
118 def TrainModel(self):
119 log.debug(
"Training model")
121 workspace.RunNetOnce(self.model.param_init_net)
124 smooth_loss = -np.log(1.0 / self.
D) * self.
seq_length 132 text_block_positions = np.zeros(self.
batch_size, dtype=np.int32)
134 text_block_starts = list(range(0, N, text_block_size))
135 text_block_sizes = [text_block_size] * self.
batch_size 137 assert sum(text_block_sizes) == N
141 workspace.FeedBlob(self.hidden_output, np.zeros(
151 last_time = datetime.now()
159 workspace.RunNet(self.prepare_state.Name())
170 pos = text_block_starts[e] + text_block_positions[e]
174 text_block_positions[e] = (
175 text_block_positions[e] + 1) % text_block_sizes[e]
178 workspace.FeedBlob(
'input_blob', input)
179 workspace.FeedBlob(
'target', target)
181 CreateNetOnce(self.model.net)
182 workspace.RunNet(self.model.net.Name())
188 new_time = datetime.now()
189 print(
"Characters Per Second: {}". format(
190 int(progress / (new_time - last_time).total_seconds())
192 print(
"Iterations Per Second: {}". format(
194 (new_time - last_time).total_seconds())
200 print(
"{} Iteration {} {}".
201 format(
'-' * 10, num_iter,
'-' * 10))
204 smooth_loss = 0.999 * smooth_loss + 0.001 * loss
210 log.debug(
"Loss since last report: {}" 211 .format(last_n_loss / last_n_iter))
212 log.debug(
"Smooth loss: {}".format(smooth_loss))
217 def GenerateText(self, num_characters, ch):
225 for _i
in range(num_characters):
227 "seq_lengths", np.array([1] * self.
batch_size, dtype=np.int32))
228 workspace.RunNet(self.prepare_state.Name())
230 input = np.zeros([1, self.
batch_size, self.
D]).astype(np.float32)
233 workspace.FeedBlob(
"input_blob", input)
234 workspace.RunNet(self.forward_net.Name())
237 next = np.random.choice(self.
D, p=p[0][0])
247 parser = argparse.ArgumentParser(
248 description=
"Caffe2: Char RNN Training" 250 parser.add_argument(
"--train_data", type=str, default=
None,
251 help=
"Path to training data in a text file format",
253 parser.add_argument(
"--seq_length", type=int, default=25,
254 help=
"One training example sequence length")
255 parser.add_argument(
"--batch_size", type=int, default=1,
256 help=
"Training batch size")
257 parser.add_argument(
"--iters_to_report", type=int, default=500,
258 help=
"How often to report loss and generate text")
259 parser.add_argument(
"--hidden_size", type=int, default=100,
260 help=
"Dimension of the hidden representation")
261 parser.add_argument(
"--gpu", action=
"store_true",
262 help=
"If set, training is going to use GPU 0")
264 args = parser.parse_args()
266 device = core.DeviceOption(
267 workspace.GpuDeviceType
if args.gpu
else caffe2_pb2.CPU, 0)
268 with core.DeviceScope(device):
274 if __name__ ==
'__main__':
275 workspace.GlobalInit([
'caffe2',
'--caffe2_log_level=2'])
def _idx_at_pos(self, pos)
def GenerateText(self, num_characters, ch)