Caffe2 - Python API
A deep learning, cross platform ML framework
debug_embed_params.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 import sys
7 import itertools
8 
9 import torch
10 import torch.jit
11 from torch.autograd import Variable
12 import torch.autograd.function as function
13 
14 import onnx
15 import caffe2.python.onnx.backend as c2
16 from test_pytorch_common import flatten
17 
18 
# Exit quietly (status 0) when torch is unavailable, so the caffe2-torch
# tests are skipped rather than reported as failures. The check runs before
# torch is first used below.
# NOTE(review): the unconditional ``import torch`` in the import block above
# executes before this guard, so a missing torch raises ImportError there
# first; that import would have to move under this guard for the skip path
# to actually be reachable.
try:
    import torch
except ImportError:
    print('Cannot import torch, hence caffe2-torch test will not run.')
    sys.exit(0)

# Use torch.FloatTensor (CPU float32) as the default tensor type.
torch.set_default_tensor_type('torch.FloatTensor')
25 
26 
def run_embed_params(proto, model, input, state_dict=None, use_gpu=True):
    """Run an exported ONNX model on the Caffe2 backend, feeding the model's
    parameters as graph inputs.

    This is only a helper debug function so we can test the
    ``embed_params=False`` case on the PyTorch front end as well. It should
    likely be removed from the release version of the code.

    Args:
        proto: Serialized ONNX ``ModelProto`` (passed to
            ``onnx.ModelProto.FromString``).
        model: The PyTorch module that was exported. The iteration order of
            its ``state_dict()`` defines the order of the parameter values.
        input: The model input(s). They are flattened together with the
            parameters and matched positionally against ``graph.input``.
        state_dict: Optional mapping of parameter name to tensor, used in
            place of the model's own values. Keys it lacks fall back to the
            model's current value.
        use_gpu: Prepare the backend with device ``'CUDA'`` when true,
            otherwise ``'CPU'``.

    Returns:
        The result of ``prepared.run(inputs=...)`` from the Caffe2 backend.
    """
    device = 'CUDA' if use_gpu else 'CPU'
    model_def = onnx.ModelProto.FromString(proto)
    onnx.checker.check_model(model_def)
    prepared = c2.prepare(model_def, device=device)

    # Take the model's parameters once. state_dict() builds a new mapping on
    # every call, so calling it per key is wasted work.
    model_params = model.state_dict()
    if state_dict:
        # The passed-in state_dict may have a different order, so follow the
        # model's order. If the PyTorch Module has added a parameter that an
        # old pre-trained state_dict does not have, pass the model's own value.
        # TODO: Even better: keyword arguments!
        # TODO: Please don't export unnecessary parameters.
        parameters = [state_dict[name] if name in state_dict else value
                      for name, value in model_params.items()]
    else:
        parameters = list(model_params.values())

    # Pair graph inputs with the flattened (input, parameters) values by
    # position. zip stops silently at the shorter of the two sequences.
    W = {}
    for graph_input, value in zip(model_def.graph.input,
                                  flatten((input, parameters))):
        if isinstance(value, Variable):
            # Unwrap to the underlying tensor before converting.
            value = value.data
        W[graph_input.name] = value.cpu().numpy()

    return prepared.run(inputs=W)
def set_default_tensor_type(t)
Definition: __init__.py:132