3 #include <torch/csrc/jit/ir.h> 4 #include <torch/csrc/jit/pybind.h> 15 const size_t EXP_BTENSOR_SIZE = 3;
// EXP_BTENSOR_SIZE (3) equals the number of entries in EXP_BTENSOR_NAME
// below: the component names "data", "mask" and "dims".
// NOTE(review): presumably an "expanded" batch tensor is represented by one
// Value per component, in this order — confirm against the implementation.
16 const std::vector<std::string> EXP_BTENSOR_NAME = {
"data",
"mask",
"dims"};
// Maps a Value* to a list of Value*.
// NOTE(review): presumably the list holds one Value per EXP_BTENSOR_NAME
// component ("data", "mask", "dims") of the batched form — confirm.
18 std::unordered_map<Value*, std::vector<Value*>> batch_map;
// Value*-to-Value* lookup table.
// NOTE(review): presumably the renaming environment consulted by rn_fn
// (whose body is not visible in this view) — confirm.
21 std::unordered_map<Value*, Value*> rn_env;
22 std::function<Value*(Value*)> rn_fn = [
this](
Value* v) {
// Returns a shared Graph for the operator called `name`.
// `input_num` defaults to -1.
// NOTE(review): -1 presumably means "number of inputs not specified", and
// the Graph presumably comes from a batch-operator table — confirm against
// the definition.
27 std::shared_ptr<Graph> getBatchOperator(
28 const std::string& name,
29 int64_t input_num = -1);
39 static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
// Exported (TORCH_API) entry point taking a source Block and a result Block.
// NOTE(review): presumably it emits the batched form of `block` into
// `res_block` — confirm against the definition.
41 TORCH_API
void toBatch(
Block* block,
Block* res_block);
// Exported (TORCH_API) free function: takes a Graph by shared_ptr (by value)
// and returns a shared_ptr<Graph>.
// NOTE(review): presumably the returned graph is the batched version of
// `graph` — confirm.
44 TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
// Exported (TORCH_API) hook that receives a Python module object.
// NOTE(review): presumably it registers Python bindings for the batch ops on
// `module` (this header includes pybind.h) — confirm.
45 TORCH_API
void initRegisterBatchOpsBindings(PyObject* module);