3 from __future__
import absolute_import, division, print_function, unicode_literals
12 Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf 21 name=
"adaptive_weight",
24 enable_diagnose=
False,
25 estimation_method=
"log_std",
26 pos_optim_method=
"log_barrier",
30 super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs)
34 self.
data = self.input_record.field_blobs()
37 if weights
is not None:
38 assert len(weights) == self.
num 40 weights = [1. / self.
num for _
in range(self.
num)]
41 assert min(weights) > 0,
"initial weights must be positive" 42 self.
weights = np.array(weights).astype(np.float32)
56 for i
in range(self.
num)
58 for i
in range(self.
num):
59 self.model.add_ad_hoc_plot_blob(self.
weight_i[i])
61 def concat_data(self, net):
62 reshaped = [net.NextScopedBlob(
"reshaped_data_%d" % i)
for i
in range(self.
num)]
64 for i
in range(self.
num):
67 [reshaped[i], net.NextScopedBlob(
"new_shape_%d" % i)],
70 concated = net.NextScopedBlob(
"concated_data")
72 reshaped, [concated, net.NextScopedBlob(
"concated_new_shape")], axis=0
78 mu = 2 log sigma, sigma = standard variance 80 min 1 / 2 / e^mu X + mu / 2 82 values = np.log(1. / 2. / self.
weights)
85 {
"values": values,
"dtype": core.DataType.FLOAT},
90 initializer=initializer,
96 min 1 / 2 / e^mu X + mu / 2 98 mu_neg = net.NextScopedBlob(
"mu_neg")
99 net.Negative(self.
mu, mu_neg)
100 mu_neg_exp = net.NextScopedBlob(
"mu_neg_exp")
101 net.Exp(mu_neg, mu_neg_exp)
102 net.Scale(mu_neg_exp, weight, scale=0.5)
def log_std_reg(self, net, reg):
    """Add the regularization term for the "log_std" estimation method.

    Emits a single ``Scale`` op on ``net`` that multiplies ``self.mu`` by
    0.5, i.e. the ``mu / 2`` part of the per-task objective
    ``1 / 2 / e^mu * X + mu / 2``.

    Args:
        net: net object the op is added to; only ``net.Scale`` is used.
        reg: second positional argument of ``Scale`` (presumably the
            output blob that receives ``mu / 2`` -- confirm against the
            caffe2 ``net.Op(input, output)`` convention).

    Returns:
        None. The result of ``net.Scale`` is intentionally discarded.
    """
    log_variance = self.mu
    net.Scale(log_variance, reg, scale=0.5)
111 min 1 / 2 * k X - 1 / 2 * log k 116 {
"values": values,
"dtype": core.DataType.FLOAT},
124 "unknown positivity optimization method: {}".format(
131 initializer=initializer,
133 regularizer=regularizer,
def inv_var_weight(self, x, net, weight):
    """Add the per-task weight for the "inv_var" estimation method.

    Emits a single ``Scale`` op on ``net`` that multiplies ``self.k`` by
    0.5, i.e. the ``1 / 2 * k`` factor of the per-task objective
    ``1 / 2 * k * X - 1 / 2 * log k``.

    Args:
        x: unused here; kept so the signature parallels the other
            ``*_weight`` methods.
        net: net object the op is added to; only ``net.Scale`` is used.
        weight: second positional argument of ``Scale`` (presumably the
            output blob that receives ``k / 2`` -- confirm against the
            caffe2 ``net.Op(input, output)`` convention).

    Returns:
        None. The result of ``net.Scale`` is intentionally discarded.
    """
    inverse_variance = self.k
    net.Scale(inverse_variance, weight, scale=0.5)
def inv_var_reg(self, net, reg):
    """Add the regularization term for the "inv_var" estimation method.

    Emits two ops on ``net``: ``Log`` applied to ``self.k`` into a fresh
    scoped blob named "log_k", then ``Scale`` of that blob by -0.5. This
    corresponds to the ``- 1 / 2 * log k`` part of the per-task objective
    ``1 / 2 * k * X - 1 / 2 * log k``.

    Args:
        net: net object the ops are added to; ``NextScopedBlob``, ``Log``
            and ``Scale`` are used.
        reg: second positional argument of ``Scale`` (presumably the
            output blob that receives ``-log(k) / 2`` -- confirm against
            the caffe2 ``net.Op(input, output)`` convention).

    Returns:
        None. The results of the op calls are intentionally discarded.
    """
    # Intermediate blob holding log(k); allocated before any op is added.
    log_of_k = net.NextScopedBlob("log_k")
    inverse_variance = self.k
    net.Log(inverse_variance, log_of_k)
    net.Scale(log_of_k, reg, scale=-0.5)
144 def _add_ops_impl(self, net, enable_diagnose):
146 weight = net.NextScopedBlob(
"weight")
147 reg = net.NextScopedBlob(
"reg")
148 weighted_x = net.NextScopedBlob(
"weighted_x")
149 weighted_x_add_reg = net.NextScopedBlob(
"weighted_x_add_reg")
152 net.Mul([weight, x], weighted_x)
153 net.Add([weighted_x, reg], weighted_x_add_reg)
156 for i
in range(self.
num):
157 net.Slice(weight, self.
weight_i[i], starts=[i], ends=[i + 1])
159 def add_ops(self, net):
def get_next_blob_reference(self, name)
def concat_data(self, net)
def _add_ops_impl(self, net, enable_diagnose)
def log_std_weight(self, x, net, weight)
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)