Caffe2 - C++ API
A deep-learning, cross-platform ML framework
to_batch.h
1 #pragma once
2 
3 #include <torch/csrc/jit/ir.h>
4 #include <torch/csrc/jit/pybind.h>
5 
6 #include <ATen/ATen.h>
7 
8 namespace torch {
9 namespace jit {
10 
11 class ToBatch {
12  private:
13  // number of tensors to represent a expanded BatchTensor. {data, mask, dims}
14  // for now.
15  const size_t EXP_BTENSOR_SIZE = 3;
16  const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
17  // mapping from tensor in original graph to {data, mask, dims} in new graph
18  std::unordered_map<Value*, std::vector<Value*>> batch_map;
19  // mapping from input in original graph to new input in new graph - used in
20  // createClone
21  std::unordered_map<Value*, Value*> rn_env;
22  std::function<Value*(Value*)> rn_fn = [this](Value* v) {
23  return rn_env.at(v);
24  };
25 
26  private:
27  std::shared_ptr<Graph> getBatchOperator(
28  const std::string& name,
29  int64_t input_num = -1);
30  void visitAten(Node* n, Block* block, Block* res_block);
31  void visitConstant(Node* n, Block* block, Block* res_block);
32  void visitNumToTensor(Node* n, Block* block, Block* res_block);
33  void visitTensorToNum(Node* n, Block* block, Block* res_block);
34  void visitListConstruct(Node* n, Block* block, Block* res_block);
35  void visitIf(Node* n, Block* block, Block* res_block);
36  void visitLoop(Node* n, Block* block, Block* res_block);
37 
38  public:
39  static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
40  batch_operator_table;
41  TORCH_API void toBatch(Block* block, Block* res_block);
42 };
43 
44 TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
45 TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
46 } // namespace jit
47 } // namespace torch
Definition: jit_type.h:17