Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_util.h
1 #ifndef NOM_TESTS_TEST_UTIL_H
2 #define NOM_TESTS_TEST_UTIL_H
3 
4 #include "caffe2/core/common.h"
5 #include "nomnigraph/Graph/Graph.h"
6 #include "nomnigraph/Graph/Algorithms.h"
7 #include "nomnigraph/Representations/NeuralNet.h"
8 #include "nomnigraph/Converters/Dot.h"
9 
10 #include <map>
11 
// Minimal payload type used as node data in nom::Graph<TestClass> tests.
// The special members are explicitly defaulted instead of being given empty
// bodies; this keeps them trivial (C++ Core Guidelines C.80) while callers can
// still default-construct, copy and destroy the type exactly as before.
class TestClass {
 public:
  TestClass() = default;
  ~TestClass() = default;
};
17 
18 struct NNEquality {
19  static bool equal(
20  const typename nom::repr::NNGraph::NodeRef& a,
21  const typename nom::repr::NNGraph::NodeRef& b) {
22  if (
23  !nom::repr::nn::is<nom::repr::NeuralNetOperator>(a) ||
24  !nom::repr::nn::is<nom::repr::NeuralNetOperator>(b)) {
25  return false;
26  }
27  auto a_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(a);
28  auto b_ = nom::repr::nn::get<nom::repr::NeuralNetOperator>(b);
29 
30  bool sameKind = a_->getKind() == b_->getKind();
31  if (sameKind && a_->getKind() == nom::repr::NeuralNetOperator::NNKind::GenericOperator) {
32  return a_->getName() == b_->getName();
33  }
34  return sameKind;
35  }
36 };
37 
38 // Very simple random number generator used to generate platform independent
39 // random test data.
40 class TestRandom {
41  public:
42  TestRandom(unsigned int seed) : seed_(seed){};
43 
44  unsigned int nextInt() {
45  seed_ = A * seed_ + C;
46  return seed_;
47  }
48 
49  private:
50  static const unsigned int A = 1103515245;
51  static const unsigned int C = 12345;
52  unsigned int seed_;
53 };
54 
// Test-graph factories and Dot-printer callbacks. Only declarations live here;
// the definitions are outside this header.

// Returns a nom::Graph whose node data are std::string labels.
// NOTE(review): the exact topology is defined out of view — confirm against
// the definition before relying on specific node/edge counts.
105 CAFFE2_API nom::Graph<std::string> createGraph();
106 
// Returns a std::string-labelled nom::Graph; judging by the name it contains
// at least one cycle — TODO confirm against the definition.
107 CAFFE2_API nom::Graph<std::string> createGraphWithCycle();
108 
// Maps an NNCFGraph node to a string->string attribute map. Presumably used as
// a node printer with nomnigraph/Converters/Dot.h — confirm at call sites.
// NOTE(review): unlike the surrounding declarations this (and the two below)
// is not marked CAFFE2_API; if it must be exported from a shared library on
// Windows, that may need fixing — verify against the build setup.
109 std::map<std::string, std::string> BBPrinter(typename nom::repr::NNCFGraph::NodeRef node);
110 
// Maps an NNCFGraph edge to a string->string attribute map (presumably an
// edge printer for the Dot converter — confirm at call sites).
111 std::map<std::string, std::string> cfgEdgePrinter(typename nom::repr::NNCFGraph::EdgeRef edge);
112 
// Maps an NNGraph node to a string->string attribute map (presumably a node
// printer for the Dot converter — confirm at call sites).
113 std::map<std::string, std::string> NNPrinter(typename nom::repr::NNGraph::NodeRef node);
114 
115 CAFFE2_API nom::Graph<TestClass>::NodeRef createTestNode(
117 
118 CAFFE2_API std::map<std::string, std::string> TestNodePrinter(
120 #endif // NOM_TESTS_TEST_UTIL_H
Definition: static.cpp:52
Definition: static.cpp:64
A simple graph implementation.
Definition: Graph.h:29