Caffe2 - C++ API
A deep learning, cross platform ML framework
integration.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/data.h>
4 #include <torch/nn/modules/batchnorm.h>
5 #include <torch/nn/modules/conv.h>
6 #include <torch/nn/modules/dropout.h>
7 #include <torch/nn/modules/linear.h>
8 #include <torch/optim/adam.h>
9 #include <torch/optim/optimizer.h>
10 #include <torch/optim/sgd.h>
11 #include <torch/types.h>
12 #include <torch/utils.h>
13 
14 #include <test/cpp/api/support.h>
15 
16 #include <cmath>
17 #include <cstdlib>
18 #include <random>
19 
20 using namespace torch::nn;
21 using namespace torch::test;
22 
// Value of pi used to convert the cart-pole failure angle from degrees to radians.
constexpr double kPi = 3.1415926535898;
24 
25 class CartPole {
26  // Translated from openai/gym's cartpole.py
27  public:
28  double gravity = 9.8;
29  double masscart = 1.0;
30  double masspole = 0.1;
31  double total_mass = (masspole + masscart);
32  double length = 0.5; // actually half the pole's length;
33  double polemass_length = (masspole * length);
34  double force_mag = 10.0;
35  double tau = 0.02; // seconds between state updates;
36 
37  // Angle at which to fail the episode
38  double theta_threshold_radians = 12 * 2 * kPi / 360;
39  double x_threshold = 2.4;
40  int steps_beyond_done = -1;
41 
42  torch::Tensor state;
43  double reward;
44  bool done;
45  int step_ = 0;
46 
47  torch::Tensor getState() {
48  return state;
49  }
50 
51  double getReward() {
52  return reward;
53  }
54 
55  double isDone() {
56  return done;
57  }
58 
59  void reset() {
60  state = torch::empty({4}).uniform_(-0.05, 0.05);
61  steps_beyond_done = -1;
62  step_ = 0;
63  }
64 
65  CartPole() {
66  reset();
67  }
68 
69  void step(int action) {
70  auto x = state[0].item<float>();
71  auto x_dot = state[1].item<float>();
72  auto theta = state[2].item<float>();
73  auto theta_dot = state[3].item<float>();
74 
75  auto force = (action == 1) ? force_mag : -force_mag;
76  auto costheta = std::cos(theta);
77  auto sintheta = std::sin(theta);
78  auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
79  total_mass;
80  auto thetaacc = (gravity * sintheta - costheta * temp) /
81  (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
82  auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
83 
84  x = x + tau * x_dot;
85  x_dot = x_dot + tau * xacc;
86  theta = theta + tau * theta_dot;
87  theta_dot = theta_dot + tau * thetaacc;
88  state = torch::tensor({x, x_dot, theta, theta_dot});
89 
90  done = x < -x_threshold || x > x_threshold ||
91  theta < -theta_threshold_radians || theta > theta_threshold_radians ||
92  step_ > 200;
93 
94  if (!done) {
95  reward = 1.0;
96  } else if (steps_beyond_done == -1) {
97  // Pole just fell!
98  steps_beyond_done = 0;
99  reward = 0;
100  } else {
101  if (steps_beyond_done == 0) {
102  AT_ASSERT(false); // Can't do this
103  }
104  }
105  step_++;
106  }
107 };
108 
109 template <typename M, typename F, typename O>
110 bool test_mnist(
111  size_t batch_size,
112  size_t number_of_epochs,
113  bool with_cuda,
114  M&& model,
115  F&& forward_op,
116  O&& optimizer) {
117  std::string mnist_path = "mnist";
118  if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH")) {
119  mnist_path = user_mnist_path;
120  }
121 
122  auto train_dataset =
124  mnist_path, torch::data::datasets::MNIST::Mode::kTrain)
126 
127  auto data_loader =
128  torch::data::make_data_loader(std::move(train_dataset), batch_size);
129 
130  torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
131  model->to(device);
132 
133  for (size_t epoch = 0; epoch < number_of_epochs; epoch++) {
134  for (torch::data::Example<> batch : *data_loader) {
135  auto data = batch.data.to(device), targets = batch.target.to(device);
136  torch::Tensor prediction = forward_op(std::move(data));
137  torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
138  AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
139  optimizer.zero_grad();
140  loss.backward();
141  optimizer.step();
142  }
143  }
144 
145  torch::NoGradGuard guard;
146  torch::data::datasets::MNIST test_dataset(
147  mnist_path, torch::data::datasets::MNIST::Mode::kTest);
148  auto images = test_dataset.images().to(device),
149  targets = test_dataset.targets().to(device);
150 
151  auto result = std::get<1>(forward_op(images).max(/*dim=*/1));
152  torch::Tensor correct = (result == targets).to(torch::kFloat32);
153  return correct.sum().item<float>() > (test_dataset.size().value() * 0.8);
154 }
155 
157 
158 TEST_F(IntegrationTest, CartPole) {
159  torch::manual_seed(0);
160  auto model = std::make_shared<SimpleContainer>();
161  auto linear = model->add(Linear(4, 128), "linear");
162  auto policyHead = model->add(Linear(128, 2), "policy");
163  auto valueHead = model->add(Linear(128, 1), "action");
164  auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);
165 
166  std::vector<torch::Tensor> saved_log_probs;
167  std::vector<torch::Tensor> saved_values;
168  std::vector<float> rewards;
169 
170  auto forward = [&](torch::Tensor inp) {
171  auto x = linear->forward(inp).clamp_min(0);
172  torch::Tensor actions = policyHead->forward(x);
173  torch::Tensor value = valueHead->forward(x);
174  return std::make_tuple(torch::softmax(actions, -1), value);
175  };
176 
177  auto selectAction = [&](torch::Tensor state) {
178  // Only work on single state right now, change index to gather for batch
179  auto out = forward(state);
180  auto probs = torch::Tensor(std::get<0>(out));
181  auto value = torch::Tensor(std::get<1>(out));
182  auto action = probs.multinomial(1)[0].item<int32_t>();
183  // Compute the log prob of a multinomial distribution.
184  // This should probably be actually implemented in autogradpp...
185  auto p = probs / probs.sum(-1, true);
186  auto log_prob = p[action].log();
187  saved_log_probs.emplace_back(log_prob);
188  saved_values.push_back(value);
189  return action;
190  };
191 
192  auto finishEpisode = [&] {
193  auto R = 0.;
194  for (int i = rewards.size() - 1; i >= 0; i--) {
195  R = rewards[i] + 0.99 * R;
196  rewards[i] = R;
197  }
198  auto r_t = torch::from_blob(
199  rewards.data(), {static_cast<int64_t>(rewards.size())});
200  r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);
201 
202  std::vector<torch::Tensor> policy_loss;
203  std::vector<torch::Tensor> value_loss;
204  for (auto i = 0U; i < saved_log_probs.size(); i++) {
205  auto r = rewards[i] - saved_values[i].item<float>();
206  policy_loss.push_back(-r * saved_log_probs[i]);
207  value_loss.push_back(
208  torch::smooth_l1_loss(saved_values[i], torch::ones(1) * rewards[i]));
209  }
210 
211  auto loss =
212  torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();
213 
214  optimizer.zero_grad();
215  loss.backward();
216  optimizer.step();
217 
218  rewards.clear();
219  saved_log_probs.clear();
220  saved_values.clear();
221  };
222 
223  auto env = CartPole();
224  double running_reward = 10.0;
225  for (size_t episode = 0;; episode++) {
226  env.reset();
227  auto state = env.getState();
228  int t = 0;
229  for (; t < 10000; t++) {
230  auto action = selectAction(state);
231  env.step(action);
232  state = env.getState();
233  auto reward = env.getReward();
234  auto done = env.isDone();
235 
236  rewards.push_back(reward);
237  if (done)
238  break;
239  }
240 
241  running_reward = running_reward * 0.99 + t * 0.01;
242  finishEpisode();
243  /*
244  if (episode % 10 == 0) {
245  printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
246  episode, t, running_reward);
247  }
248  */
249  if (running_reward > 150) {
250  break;
251  }
252  ASSERT_LT(episode, 3000);
253  }
254 }
255 
256 TEST_F(IntegrationTest, MNIST_CUDA) {
257  torch::manual_seed(0);
258  auto model = std::make_shared<SimpleContainer>();
259  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
260  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
261  auto drop = Dropout(0.3);
262  auto drop2d = FeatureDropout(0.3);
263  auto linear1 = model->add(Linear(320, 50), "linear1");
264  auto linear2 = model->add(Linear(50, 10), "linear2");
265 
266  auto forward = [&](torch::Tensor x) {
267  x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
268  x = conv2->forward(x);
269  x = drop2d->forward(x);
270  x = torch::max_pool2d(x, {2, 2}).relu();
271 
272  x = x.view({-1, 320});
273  x = linear1->forward(x).clamp_min(0);
274  x = drop->forward(x);
275  x = linear2->forward(x);
276  x = torch::log_softmax(x, 1);
277  return x;
278  };
279 
280  auto optimizer = torch::optim::SGD(
281  model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
282 
283  ASSERT_TRUE(test_mnist(
284  32, // batch_size
285  3, // number_of_epochs
286  true, // with_cuda
287  model,
288  forward,
289  optimizer));
290 }
291 
292 TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
293  torch::manual_seed(0);
294  auto model = std::make_shared<SimpleContainer>();
295  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
296  auto batchnorm2d = model->add(BatchNorm(10), "batchnorm2d");
297  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
298  auto linear1 = model->add(Linear(320, 50), "linear1");
299  auto batchnorm1 = model->add(BatchNorm(50), "batchnorm1");
300  auto linear2 = model->add(Linear(50, 10), "linear2");
301 
302  auto forward = [&](torch::Tensor x) {
303  x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
304  x = batchnorm2d->forward(x);
305  x = conv2->forward(x);
306  x = torch::max_pool2d(x, {2, 2}).relu();
307 
308  x = x.view({-1, 320});
309  x = linear1->forward(x).clamp_min(0);
310  x = batchnorm1->forward(x);
311  x = linear2->forward(x);
312  x = torch::log_softmax(x, 1);
313  return x;
314  };
315 
316  auto optimizer = torch::optim::SGD(
317  model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
318 
319  ASSERT_TRUE(test_mnist(
320  32, // batch_size
321  3, // number_of_epochs
322  true, // with_cuda
323  model,
324  forward,
325  optimizer));
326 }
optional< size_t > size() const override
Returns the size of the dataset.
Definition: mnist.cpp:107
void backward(c10::optional< Tensor > gradient=c10::nullopt, bool keep_graph=false, bool create_graph=false)
Computes the gradient of current tensor w.r.t. graph leaves.
Definition: TensorMethods.h:49
Definition: any.cpp:108
MapDataset< Self, TransformType > map(TransformType transform)&
Creates a MapDataset that applies the given transform to this dataset.
Definition: base.h:57
An Example from a dataset.
Definition: example.h:12
const Tensor & targets() const
Returns all targets stacked into a single tensor.
Definition: mnist.cpp:119
Represents a compute device on which a tensor is located.
Definition: Device.h:30
The MNIST dataset.
Definition: mnist.h:16
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Definition: module.cpp:244
const Tensor & images() const
Returns all images stacked into a single tensor.
Definition: mnist.cpp:115