Caffe2 - C++ API
A deep learning, cross-platform ML framework
registry_test.cpp
1 #include <gtest/gtest.h>
2 #include <iostream>
3 #include <memory>
4 
5 #include <c10/util/Registry.h>
6 
// Note: we use a different namespace to test whether the macros defined in
// Registry.h actually work in a namespace other than c10.
9 namespace c10_test {
10 
// Polymorphic base type stored in FooRegistry. Its single int constructor
// argument matches the `int` creator argument declared for the registry
// (C10_DECLARE_REGISTRY(FooRegistry, Foo, int)); Foo itself ignores it.
class Foo {
 public:
  // The parameter is intentionally unnamed: it is unused here, and a named
  // unused parameter triggers -Wunused-parameter.
  explicit Foo(int /*x*/) {}

  // Virtual so that objects created through the registry can be destroyed
  // through a Foo pointer (the tests hold them as std::unique_ptr<Foo>).
  virtual ~Foo() = default;
};
18 
19 C10_DECLARE_REGISTRY(FooRegistry, Foo, int);
20 C10_DEFINE_REGISTRY(FooRegistry, Foo, int);
21 #define REGISTER_FOO(clsname) C10_REGISTER_CLASS(FooRegistry, clsname, clsname)
22 
23 class Bar : public Foo {
24  public:
25  explicit Bar(int x) : Foo(x) {
26  // LOG(INFO) << "Bar " << x;
27  }
28 };
29 REGISTER_FOO(Bar);
30 
31 class AnotherBar : public Foo {
32  public:
33  explicit AnotherBar(int x) : Foo(x) {
34  // LOG(INFO) << "AnotherBar " << x;
35  }
36 };
37 REGISTER_FOO(AnotherBar);
38 
39 TEST(RegistryTest, CanRunCreator) {
40  std::unique_ptr<Foo> bar(FooRegistry()->Create("Bar", 1));
41  EXPECT_TRUE(bar != nullptr) << "Cannot create bar.";
42  std::unique_ptr<Foo> another_bar(FooRegistry()->Create("AnotherBar", 1));
43  EXPECT_TRUE(another_bar != nullptr);
44 }
45 
46 TEST(RegistryTest, ReturnNullOnNonExistingCreator) {
47  EXPECT_EQ(FooRegistry()->Create("Non-existing bar", 1), nullptr);
48 }
49 
50 // C10_REGISTER_CLASS_WITH_PRIORITY defines static variable
51 void RegisterFooDefault() {
52  C10_REGISTER_CLASS_WITH_PRIORITY(
53  FooRegistry, FooWithPriority, c10::REGISTRY_DEFAULT, Foo);
54 }
55 
56 void RegisterFooDefaultAgain() {
57  C10_REGISTER_CLASS_WITH_PRIORITY(
58  FooRegistry, FooWithPriority, c10::REGISTRY_DEFAULT, Foo);
59 }
60 
61 void RegisterFooBarFallback() {
62  C10_REGISTER_CLASS_WITH_PRIORITY(
63  FooRegistry, FooWithPriority, c10::REGISTRY_FALLBACK, Bar);
64 }
65 
66 void RegisterFooBarPreferred() {
67  C10_REGISTER_CLASS_WITH_PRIORITY(
68  FooRegistry, FooWithPriority, c10::REGISTRY_PREFERRED, Bar);
69 }
70 
71 TEST(RegistryTest, RegistryPriorities) {
72  FooRegistry()->SetTerminate(false);
73  RegisterFooDefault();
74 
75  // throws because Foo is already registered with default priority
76  EXPECT_THROW(RegisterFooDefaultAgain(), std::runtime_error);
77 
78 #ifdef __GXX_RTTI
79  // not going to register Bar because Foo is registered with Default priority
80  RegisterFooBarFallback();
81  std::unique_ptr<Foo> bar1(FooRegistry()->Create("FooWithPriority", 1));
82  EXPECT_EQ(dynamic_cast<Bar*>(bar1.get()), nullptr);
83 
84  // will register Bar because of higher priority
85  RegisterFooBarPreferred();
86  std::unique_ptr<Foo> bar2(FooRegistry()->Create("FooWithPriority", 1));
87  EXPECT_NE(dynamic_cast<Bar*>(bar2.get()), nullptr);
88 #endif
89 }
90 
91 } // namespace c10_test