Caffe2 - Python API
A deep learning, cross-platform ML framework
test_cuda.py
1 import torch
2 import torch.nn as nn
3 
4 
class Model(nn.Module):
    """Tiny module that projects a fixed column window of its input.

    Only columns 10 through 29 of the second dimension are passed through a
    20 -> 20 linear layer, and the projected values are summed into a single
    scalar tensor (convenient as a loss for backward()).
    """

    def __init__(self):
        super(Model, self).__init__()
        # 20 input features == width of the 10:30 window taken in forward().
        self.linear = nn.Linear(20, 20)

    def forward(self, input):
        """Return ``linear(input[:, 10:30])`` reduced to a 0-d sum.

        NOTE(review): ``input`` presumably is (batch, features) with at least
        30 features (the caller in this file uses (10, 50)) -- confirm.
        """
        window = input[:, 10:30]
        projected = self.linear(window)
        return torch.sum(projected)
13 
14 
def main(device='cuda', steps=10):
    """Run a short SGD training loop of ``Model`` on random data.

    Args:
        device: Where the data and model are placed. Defaults to ``'cuda'``,
            which preserves the original CUDA-only behavior. Pass ``'cpu'``
            to exercise the same loop on a machine without a GPU.
        steps: Number of optimizer iterations (default 10, as before).
    """
    # Draw on the CPU and then move the tensor, exactly like the original
    # ``torch.randn(...).cuda()``, so the sampled values do not change.
    data = torch.randn(10, 50).to(device)
    model = Model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = model(data)
        loss.backward()
        optimizer.step()
24 
25 
if __name__ == "__main__":
    # Only run the training loop when executed directly, not on import.
    main()
Definition: model.py:1