Caffe2 - Python API
A deep learning, cross-platform ML framework
import torch
import torch.nn as nn

# Configurable hyperparameters shared by the networks below.
bsz = 64     # presumably the batch size; not referenced in this listing
imgsz = 64   # not referenced here: both layer stacks are hard-wired for
             # 64 x 64 images (see their state-size comments)
nz = 100     # channel count of the generator input Z (latent vector size)
ngf = 64     # base channel multiplier for the generator's layers
ndf = 64     # base channel multiplier for the discriminator's layers
nc = 3       # image channels: generator output / discriminator input
# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialize the parameters of a single module, DCGAN-style.

    Presumably applied to every submodule via ``net.apply(weights_init)``.
    Modules are selected by class name:

    * name contains ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
      weight drawn from N(mean=0.0, std=0.02); the bias is not touched.
    * name contains ``'BatchNorm'``: weight drawn from N(mean=1.0, std=0.02)
      and bias set to 0.
    * anything else: left unchanged.

    A module that matches by name but has no weight/bias tensor (for example
    ``BatchNorm2d(affine=False)``, whose ``weight``/``bias`` are ``None``)
    is skipped for that tensor instead of raising ``AttributeError``.

    Args:
        m: the ``nn.Module`` being visited.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # Centre the scale on 1 so BatchNorm starts close to identity.
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
class _netG(nn.Module):
    """DCGAN generator: maps a latent batch to ``nc``-channel 64x64 images.

    The layer widths come from the module-level ``nz``, ``ngf`` and ``nc``
    settings. The first ConvTranspose2d (kernel 4, stride 1, padding 0)
    produces a 4x4 map, so the input is expected to be ``(N, nz, 1, 1)``.
    The final ``Tanh`` keeps outputs in [-1, 1].

    Args:
        ngpu: number of GPUs to split the forward pass over. Values > 1
            only take effect when the input is already on a CUDA device.
    """

    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Generate images from a latent batch.

        Args:
            input: latent tensor, expected shape ``(N, nz, 1, 1)``.

        Returns:
            Tensor of shape ``(N, nc, 64, 64)`` with values in [-1, 1].
        """
        # Split across devices only for multi-GPU runs with CUDA input;
        # otherwise run the stack on whatever device already holds it.
        if self.ngpu > 1 and input.is_cuda:
            output = nn.parallel.data_parallel(
                self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)
        return output
class _netD(nn.Module):
    """DCGAN discriminator: scores ``nc``-channel 64x64 images.

    The layer widths come from the module-level ``nc`` and ``ndf`` settings.
    Four stride-2 convolutions reduce 64x64 to 4x4. A final 4x4 convolution
    (stride 1, padding 0) then reduces that to a single value per image,
    which ``Sigmoid`` squashes into (0, 1).

    Args:
        ngpu: number of GPUs to split the forward pass over. Values > 1
            only take effect when the input is already on a CUDA device.
    """

    def __init__(self, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image tensor, expected shape ``(N, nc, 64, 64)``.

        Returns:
            Tensor of shape ``(N, 1)`` holding one value in (0, 1) per
            image (flattened from the ``(N, 1, 1, 1)`` conv output).
        """
        # Split across devices only for multi-GPU runs with CUDA input;
        # otherwise run the stack on whatever device already holds it.
        if self.ngpu > 1 and input.is_cuda:
            output = nn.parallel.data_parallel(
                self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)
        return output.view(-1, 1)