Caffe2 - Python API
A deep learning, cross-platform ML framework

Navigation: Packages · Classes · Files · C++ API · Python API · GitHub
File List › test › bottleneck › test_cuda.py
1
import
torch
2
import
torch.nn
as
nn
3
4
5
class Model(nn.Module):
    """Tiny network: one 20 -> 20 linear layer whose outputs are summed.

    Attributes:
        linear: the ``nn.Linear(20, 20)`` layer applied in :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(20, 20)

    def forward(self, input):
        """Apply ``linear`` to columns 10..29 of ``input`` and return the total sum.

        Args:
            input: tensor indexed as ``input[:, 10:30]``. The slice must leave
                exactly 20 trailing features for ``linear``. NOTE(review):
                presumably a 2-D (batch, features) tensor, e.g. the (10, 50)
                batch built in ``main``. Confirm this for other callers.

        Returns:
            A scalar tensor holding the sum of every output element.
        """
        window = input[:, 10:30]  # 20-wide column window matching Linear(20, 20)
        projected = self.linear(window)
        return projected.sum()
13
14
15
def main():
    """Train ``Model`` for ten plain-SGD steps on one fixed random batch on the GPU.

    NOTE(review): the file lives under test/bottleneck/, so this is presumably a
    small CUDA workload for profiling with torch.utils.bottleneck. Confirm.
    """
    # Keep this ordering: the random batch is drawn before the model's
    # parameters are initialized, so the RNG stream is consumed in the same order.
    data = torch.randn(10, 50).cuda()
    model = Model().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

    for _step in range(10):
        optimizer.zero_grad()
        model(data).backward()  # Model.forward already returns a scalar loss
        optimizer.step()


if __name__ == "__main__":
    main()
Cross-references:
- test_cuda.Model.linear (shown as `linear`): defined at test_cuda.py:8
- model: defined at model.py:1
- test_cuda.Model: defined at test_cuda.py:5
- torch.nn: defined at __init__.py:1
Generated on Thu Mar 21 2019 13:06:36 for Caffe2 - Python API by Doxygen 1.8.11