| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| from caffe2.python import core, scope |
| from caffe2.python.model_helper import ModelHelperBase |
| from caffe2.proto import caffe2_pb2 |
| |
| |
| class CNNModelHelper(ModelHelperBase): |
| """A helper model so we can write CNN models more easily, without having to |
| manually define parameter initializations and operators separately. |
| """ |
| |
| def __init__(self, order="NCHW", name=None, |
| use_cudnn=True, cudnn_exhaustive_search=False, |
| ws_nbytes_limit=None, init_params=True, |
| skip_sparse_optim=False, |
| param_model=None): |
| |
| super(CNNModelHelper, self).__init__( |
| skip_sparse_optim=skip_sparse_optim, |
| name="CNN" if name is None else name, |
| init_params=init_params, |
| param_model=param_model, |
| ) |
| |
| self.weights = [] |
| self.biases = [] |
| self.order = order |
| self.use_cudnn = use_cudnn |
| self.cudnn_exhaustive_search = cudnn_exhaustive_search |
| self.ws_nbytes_limit = ws_nbytes_limit |
        if self.order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % self.order
            )
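
    # Example usage (a minimal sketch; assumes
    # `from caffe2.python import workspace`, and the model/blob names are
    # illustrative):
    #
    #   model = CNNModelHelper(order="NCHW", name="example_cnn")
    #   # ... build the model with model.Conv / model.FC / ...
    #   workspace.RunNetOnce(model.param_init_net)
    #   workspace.CreateNet(model.net)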
| |
| def GetWeights(self, namescope=None): |
| if namescope is None: |
| namescope = scope.CurrentNameScope() |
| |
| if namescope == '': |
| return self.weights[:] |
| else: |
| return [w for w in self.weights if w.GetNameScope() == namescope] |
| |
| def GetBiases(self, namescope=None): |
| if namescope is None: |
| namescope = scope.CurrentNameScope() |
| |
| if namescope == '': |
| return self.biases[:] |
| else: |
| return [b for b in self.biases if b.GetNameScope() == namescope] |
| |
| def ImageInput( |
| self, blob_in, blob_out, use_gpu_transform=False, **kwargs |
| ): |
| """Image Input.""" |
| if self.order == "NCHW": |
            if use_gpu_transform:
                kwargs['use_gpu_transform'] = 1
                # The GPU transform handles the NHWC -> NCHW conversion.
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0], blob_out[1]], **kwargs)
| else: |
| data, label = self.net.ImageInput( |
| blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs) |
| data = self.net.NHWC2NCHW(data, blob_out[0]) |
| else: |
| data, label = self.net.ImageInput( |
| blob_in, blob_out, **kwargs) |
| return data, label |
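
    # Example (a sketch; assumes a database reader blob created elsewhere,
    # e.g. with the CreateDB operator on param_init_net, and illustrative
    # operator arguments):
    #
    #   data, label = model.ImageInput(
    #       [reader], ["data", "label"], batch_size=32, scale=256, crop=224)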
| |
| def Conv( |
| self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None, |
| bias_init=None, group=1, **kwargs |
| ): |
| """Convolution. We intentionally do not provide odd kernel/stride/pad |
| settings in order to discourage the use of odd cases. |
| """ |
        use_bias = not kwargs.get("no_bias", False)
| weight_init = weight_init if weight_init else ('XavierFill', {}) |
| bias_init = bias_init if bias_init else ('ConstantFill', {}) |
| blob_out = blob_out or self.net.NextName() |
| weight_shape = ( |
| [dim_out, int(dim_in / group), kernel, kernel] |
| if self.order == "NCHW" else |
| [dim_out, kernel, kernel, int(dim_in / group)] |
| ) |
| if self.init_params: |
| weight = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_w', |
| shape=weight_shape, |
| **weight_init[1] |
| ) |
| if use_bias: |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_b', |
| shape=[dim_out, ], |
| **bias_init[1] |
| ) |
| else: |
| weight = core.ScopedBlobReference( |
| blob_out + '_w', self.param_init_net) |
| if use_bias: |
| bias = core.ScopedBlobReference( |
| blob_out + '_b', self.param_init_net) |
| if use_bias: |
| self.params.extend([weight, bias]) |
| else: |
| self.params.extend([weight]) |
| |
| self.weights.append(weight) |
| |
| if use_bias: |
| self.biases.append(bias) |
| |
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| kwargs['exhaustive_search'] = self.cudnn_exhaustive_search |
| if self.ws_nbytes_limit: |
| kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit |
| |
| inputs = [] |
| if use_bias: |
| inputs = [blob_in, weight, bias] |
| else: |
| inputs = [blob_in, weight] |
| |
| # For the operator, we no longer need to provide the no_bias field |
| # because it can automatically figure this out from the number of |
| # inputs. |
| if 'no_bias' in kwargs: |
| del kwargs['no_bias'] |
| if group != 1: |
| kwargs['group'] = group |
| return self.net.Conv( |
| inputs, |
| blob_out, |
| kernel=kernel, |
| order=self.order, |
| **kwargs |
| ) |
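
    # Example (a sketch; assumes a "data" blob with three channels already
    # exists in the workspace, and the layer names are illustrative):
    #
    #   conv1 = model.Conv("data", "conv1", dim_in=3, dim_out=64, kernel=7,
    #                      stride=2, pad=3, weight_init=model.XavierInit)
    #   relu1 = model.Relu(conv1, "conv1_relu")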
| |
| def ConvTranspose( |
| self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None, |
| bias_init=None, **kwargs |
| ): |
| """ConvTranspose. |
| """ |
| weight_init = weight_init if weight_init else ('XavierFill', {}) |
| bias_init = bias_init if bias_init else ('ConstantFill', {}) |
| blob_out = blob_out or self.net.NextName() |
| weight_shape = ( |
| [dim_in, dim_out, kernel, kernel] |
| if self.order == "NCHW" else [dim_in, kernel, kernel, dim_out] |
| ) |
| if self.init_params: |
| weight = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_w', |
| shape=weight_shape, |
| **weight_init[1] |
| ) |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_b', |
| shape=[dim_out, ], |
| **bias_init[1] |
| ) |
| else: |
| weight = core.ScopedBlobReference( |
| blob_out + '_w', self.param_init_net) |
| bias = core.ScopedBlobReference( |
| blob_out + '_b', self.param_init_net) |
| self.params.extend([weight, bias]) |
| self.weights.append(weight) |
| self.biases.append(bias) |
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| kwargs['exhaustive_search'] = self.cudnn_exhaustive_search |
| if self.ws_nbytes_limit: |
| kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit |
| return self.net.ConvTranspose( |
| [blob_in, weight, bias], |
| blob_out, |
| kernel=kernel, |
| order=self.order, |
| **kwargs |
| ) |
| |
| |
| def GroupConv( |
| self, |
| blob_in, |
| blob_out, |
| dim_in, |
| dim_out, |
| kernel, |
| weight_init=None, |
| bias_init=None, |
| group=1, |
| **kwargs |
| ): |
| """Group Convolution. |
| |
| This is essentially the same as Conv with a group argument passed in. |
| We specialize this for backward interface compatibility. |
| """ |
| return self.Conv(blob_in, blob_out, dim_in, dim_out, kernel, |
| weight_init=weight_init, bias_init=bias_init, |
| group=group, **kwargs) |
| |
| def GroupConv_Deprecated( |
| self, |
| blob_in, |
| blob_out, |
| dim_in, |
| dim_out, |
| kernel, |
| weight_init=None, |
| bias_init=None, |
| group=1, |
| **kwargs |
| ): |
| """GroupConvolution's deprecated interface. |
| |
| This is used to simulate a group convolution via split and concat. You |
| should always use the new group convolution in your new code. |
| """ |
| weight_init = weight_init if weight_init else ('XavierFill', {}) |
| bias_init = bias_init if bias_init else ('ConstantFill', {}) |
        use_bias = not kwargs.get("no_bias", False)
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| kwargs['exhaustive_search'] = self.cudnn_exhaustive_search |
| if self.ws_nbytes_limit: |
| kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit |
| if dim_in % group: |
| raise ValueError("dim_in should be divisible by group.") |
| if dim_out % group: |
| raise ValueError("dim_out should be divisible by group.") |
| splitted_blobs = self.net.DepthSplit( |
| blob_in, |
| ['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)], |
| dimensions=[int(dim_in / group) for i in range(group)], |
| order=self.order |
| ) |
| weight_shape = ( |
| [dim_out / group, dim_in / group, kernel, kernel] |
| if self.order == "NCHW" else |
| [dim_out / group, kernel, kernel, dim_in / group] |
| ) |
| # Make sure that the shapes are of int format. Especially for py3 where |
| # int division gives float output. |
| weight_shape = [int(v) for v in weight_shape] |
| conv_blobs = [] |
| for i in range(group): |
| if self.init_params: |
| weight = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_gconv_%d_w' % i, |
| shape=weight_shape, |
| **weight_init[1] |
| ) |
| if use_bias: |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_gconv_%d_b' % i, |
| shape=[int(dim_out / group)], |
| **bias_init[1] |
| ) |
| else: |
| weight = core.ScopedBlobReference( |
| blob_out + '_gconv_%d_w' % i, self.param_init_net) |
| if use_bias: |
| bias = core.ScopedBlobReference( |
| blob_out + '_gconv_%d_b' % i, self.param_init_net) |
| if use_bias: |
| self.params.extend([weight, bias]) |
| else: |
| self.params.extend([weight]) |
| self.weights.append(weight) |
| if use_bias: |
| self.biases.append(bias) |
| if use_bias: |
| inputs = [weight, bias] |
| else: |
| inputs = [weight] |
| if 'no_bias' in kwargs: |
| del kwargs['no_bias'] |
| conv_blobs.append( |
| splitted_blobs[i].Conv( |
| inputs, |
| blob_out + '_gconv_%d' % i, |
| kernel=kernel, |
| order=self.order, |
| **kwargs |
| ) |
| ) |
| concat, concat_dims = self.net.Concat( |
| conv_blobs, |
| [blob_out, "_" + blob_out + "_concat_dims"], |
| order=self.order |
| ) |
| return concat |
| |
| def _FC_or_packed_FC( |
| self, op_call, blob_in, blob_out, dim_in, dim_out, weight_init=None, |
| bias_init=None, **kwargs |
| ): |
| """FC""" |
| weight_init = weight_init or ('XavierFill', {}) |
| bias_init = bias_init or ('ConstantFill', {}) |
| blob_out = blob_out or self.net.NextName() |
| if self.init_params: |
| weight = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_w', |
| shape=[dim_out, dim_in], |
| **weight_init[1] |
| ) |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_b', |
| shape=[dim_out, ], |
| **bias_init[1] |
| ) |
| else: |
| weight = core.ScopedBlobReference( |
| blob_out + '_w', self.param_init_net) |
| bias = core.ScopedBlobReference( |
| blob_out + '_b', self.param_init_net) |
| |
| if 'freeze_bias' in kwargs: |
| self.params.extend([weight]) |
| else: |
| self.params.extend([weight, bias]) |
| |
| self.weights.append(weight) |
| self.biases.append(bias) |
| return op_call([blob_in, weight, bias], blob_out, **kwargs) |
| |
| def FC(self, *args, **kwargs): |
| return self._FC_or_packed_FC(self.net.FC, *args, **kwargs) |
| |
| def PackedFC(self, *args, **kwargs): |
| return self._FC_or_packed_FC(self.net.PackedFC, *args, **kwargs) |
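
    # Example (a sketch; blob names and dimensions are illustrative):
    #
    #   fc6 = model.FC("pool5_flat", "fc6", dim_in=9216, dim_out=4096)
    #   fc6 = model.Relu(fc6, "fc6_relu")
    #   pred = model.FC(fc6, "pred", dim_in=4096, dim_out=1000)
    #   softmax = model.net.Softmax(pred, "softmax")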
| |
| def FC_Decomp( |
| self, blob_in, blob_out, dim_in, dim_out, |
| rank_approx=5, weight_init=None, |
| bias_init=None, **kwargs |
| ): |
| """FC_Decomp version |
| Here we assume that the rank of original input is bigger than 5. |
| """ |
| weight_init = weight_init if weight_init else ('XavierFill', {}) |
| bias_init = bias_init if bias_init else ('ConstantFill', {}) |
| blob_out = blob_out or self.net.NextName() |
| u = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_u', |
| shape=[dim_out, rank_approx], |
| **weight_init[1] |
| ) |
| v = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_v', |
| shape=[dim_in, rank_approx], |
| **weight_init[1] |
| ) |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_b', |
| shape=[dim_out, ], |
| **bias_init[1] |
| ) |
| self.params.extend([u, v, bias]) |
| return self.net.FC_Decomp([blob_in, u, v, bias], blob_out, **kwargs) |
| |
| def FC_Prune( |
| self, blob_in, blob_out, dim_in, dim_out, |
| weight_init=None, bias_init=None, mask_init=None, |
| threshold=0.00001, need_compress_rate=False, |
| comp_lb=0.05, |
| **kwargs |
| ): |
| """FC_Prune version |
| Runnable so far. Great!:) |
| """ |
| weight_init = weight_init if weight_init else ('XavierFill', {}) |
| bias_init = bias_init if bias_init else ('ConstantFill', {}) |
| mask_init = mask_init if mask_init else ('ConstantFill', {}) |
| blob_out = blob_out or self.net.NextName() |
| compress_rate = blob_out + '_compress_rate' |
| if self.init_params: |
| compress_lb = self.param_init_net.ConstantFill( |
| [], |
| blob_out + '_lb', |
| shape=[1], |
| value=comp_lb |
| ) |
| weight = self.param_init_net.__getattr__(weight_init[0])( |
| [], |
| blob_out + '_w', |
| shape=[dim_out, dim_in], |
| **weight_init[1] |
| ) |
| mask = self.param_init_net.ConstantFill( |
| [], |
| blob_out + '_m', |
| shape=[dim_out, dim_in], |
| value=1.0 |
| ) |
| ag_dw = self.param_init_net.__getattr__(mask_init[0])( |
| [], |
| blob_out + '_ag_dw', |
| shape=[dim_out, dim_in], |
| **mask_init[1] |
| ) |
| bias = self.param_init_net.__getattr__(bias_init[0])( |
| [], |
| blob_out + '_b', |
| shape=[dim_out, ], |
| **bias_init[1] |
| ) |
| mask_seq = self.param_init_net.__getattr__(mask_init[0])( |
| [], |
| blob_out + '_mask_seq', |
| shape=[dim_out, dim_in], |
| **mask_init[1] |
| ) |
| thres = self.param_init_net.ConstantFill( |
| [], |
| blob_out + '_thres', |
| shape=[1], |
| value=threshold |
| ) |
| else: |
| compress_lb = core.ScopedBlobReference( |
| blob_out + '_lb', self.param_init_net) |
| weight = core.ScopedBlobReference( |
| blob_out + '_w', self.param_init_net) |
| bias = core.ScopedBlobReference( |
| blob_out + '_b', self.param_init_net) |
| mask = core.ScopedBlobReference( |
| blob_out + '_m', self.param_init_net) |
| ag_dw = core.ScopedBlobReference( |
| blob_out + '_ag_dw', self.param_init_net) |
| mask_seq = core.ScopedBlobReference( |
| blob_out + '_mask_seq', self.param_init_net) |
| thres = core.ScopedBlobReference( |
| blob_out + '_thres', self.param_init_net) |
| |
| self.params.extend([weight, bias]) |
| if need_compress_rate: |
| return self.net.FC_Prune([blob_in, weight, mask, |
| bias, ag_dw, mask_seq, |
| thres, compress_lb], |
| [blob_out, compress_rate], **kwargs) |
| else: |
| return self.net.FC_Prune([blob_in, weight, mask, |
| bias, ag_dw, mask_seq, |
| thres, compress_lb], |
| blob_out, **kwargs) |
| |
| def FC_Sparse( |
| self, blob_in, blob_out, w_csr, iw, jw, bias, |
| **kwargs |
| ): |
| """FC_Sparse: Only takes in alocated weights""" |
| if not (w_csr and iw and jw and bias): |
| print("Warning...") |
| self.params.extend([w_csr, iw, jw, bias]) |
| return self.net.FC_Sparse([blob_in, w_csr, iw, jw, bias], |
| blob_out, **kwargs) |
| |
| def LRN(self, blob_in, blob_out, **kwargs): |
| """LRN""" |
| return self.net.LRN( |
| blob_in, |
| [blob_out, "_" + blob_out + "_scale"], |
| order=self.order, |
| **kwargs |
| )[0] |
| |
| def Dropout(self, blob_in, blob_out, **kwargs): |
| """Dropout""" |
| return self.net.Dropout( |
| blob_in, [blob_out, "_" + blob_out + "_mask"], **kwargs |
| )[0] |
| |
| def MaxPool(self, blob_in, blob_out, **kwargs): |
| """Max pooling""" |
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs) |
| |
| def AveragePool(self, blob_in, blob_out, **kwargs): |
| """Average pooling""" |
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| return self.net.AveragePool( |
| blob_in, |
| blob_out, |
| order=self.order, |
| **kwargs |
| ) |
| |
| def Concat(self, blobs_in, blob_out, **kwargs): |
| """Depth Concat.""" |
| return self.net.Concat( |
| blobs_in, |
| [blob_out, "_" + blob_out + "_concat_dims"], |
| order=self.order, |
| **kwargs |
| )[0] |
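
    # Example (a sketch of an Inception-style merge; the four branch blobs
    # are assumed to already exist with matching spatial dimensions):
    #
    #   mixed = model.Concat(
    #       [branch1x1, branch3x3, branch5x5, branch_pool], "inception_3a")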
| |
| def DepthConcat(self, blobs_in, blob_out, **kwargs): |
| """The old depth concat function - we should move to use concat.""" |
| print("DepthConcat is deprecated. use Concat instead.") |
| return self.Concat(blobs_in, blob_out, **kwargs) |
| |
| def PRelu(self, blob_in, blob_out, num_channels=1, slope_init=None, |
| **kwargs): |
| """PRelu""" |
| slope_init = ( |
| slope_init if slope_init else ('ConstantFill', {'value': 0.25})) |
| if self.init_params: |
| slope = self.param_init_net.__getattr__(slope_init[0])( |
| [], |
| blob_out + '_slope', |
| shape=[num_channels], |
| **slope_init[1] |
| ) |
| else: |
| slope = core.ScopedBlobReference( |
| blob_out + '_slope', self.param_init_net) |
| |
| self.params.extend([slope]) |
| |
        return self.net.PRelu([blob_in, slope], [blob_out], **kwargs)
| |
| def Relu(self, blob_in, blob_out, **kwargs): |
| """Relu.""" |
| if self.use_cudnn: |
| kwargs['engine'] = 'CUDNN' |
| return self.net.Relu(blob_in, blob_out, order=self.order, **kwargs) |
| |
| def Transpose(self, blob_in, blob_out, **kwargs): |
| """Transpose.""" |
| return self.net.Transpose(blob_in, blob_out, **kwargs) |
| |
| def Sum(self, blob_in, blob_out, **kwargs): |
| """Sum""" |
| return self.net.Sum(blob_in, blob_out, **kwargs) |
| |
| def InstanceNorm(self, blob_in, blob_out, dim_in, **kwargs): |
| blob_out = blob_out or self.net.NextName() |
| # Input: input, scale, bias |
| # Output: output, saved_mean, saved_inv_std |
| # scale: initialize with ones |
| # bias: initialize with zeros |
| |
| def init_blob(value, suffix): |
| return self.param_init_net.ConstantFill( |
| [], blob_out + "_" + suffix, shape=[dim_in], value=value) |
| scale, bias = init_blob(1.0, "s"), init_blob(0.0, "b") |
| |
| self.params.extend([scale, bias]) |
| self.weights.append(scale) |
| self.biases.append(bias) |
| blob_outs = [blob_out, blob_out + "_sm", blob_out + "_siv"] |
| if 'is_test' in kwargs and kwargs['is_test']: |
| blob_outputs = self.net.InstanceNorm( |
| [blob_in, scale, bias], [blob_out], |
| order=self.order, **kwargs) |
| return blob_outputs |
| else: |
| blob_outputs = self.net.InstanceNorm( |
| [blob_in, scale, bias], blob_outs, |
| order=self.order, **kwargs) |
| # Return the output |
| return blob_outputs[0] |
| |
| def SpatialBN(self, blob_in, blob_out, dim_in, **kwargs): |
| blob_out = blob_out or self.net.NextName() |
| # Input: input, scale, bias, est_mean, est_inv_var |
| # Output: output, running_mean, running_inv_var, saved_mean, |
| # saved_inv_var |
| # scale: initialize with ones |
| # bias: initialize with zeros |
| # est mean: zero |
| # est var: ones |
| |
| def init_blob(value, suffix): |
| return self.param_init_net.ConstantFill( |
| [], blob_out + "_" + suffix, shape=[dim_in], value=value) |
| |
| if self.init_params: |
| scale, bias = init_blob(1.0, "s"), init_blob(0.0, "b") |
| running_mean = init_blob(0.0, "rm") |
| running_inv_var = init_blob(1.0, "riv") |
| else: |
| scale = core.ScopedBlobReference( |
| blob_out + '_s', self.param_init_net) |
| bias = core.ScopedBlobReference( |
| blob_out + '_b', self.param_init_net) |
| running_mean = core.ScopedBlobReference( |
| blob_out + '_rm', self.param_init_net) |
| running_inv_var = core.ScopedBlobReference( |
| blob_out + '_riv', self.param_init_net) |
| |
| self.params.extend([scale, bias]) |
| self.computed_params.extend([running_mean, running_inv_var]) |
| self.weights.append(scale) |
| self.biases.append(bias) |
| blob_outs = [blob_out, running_mean, running_inv_var, |
| blob_out + "_sm", blob_out + "_siv"] |
| if 'is_test' in kwargs and kwargs['is_test']: |
| blob_outputs = self.net.SpatialBN( |
| [blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out], |
| order=self.order, **kwargs) |
| return blob_outputs |
| else: |
| blob_outputs = self.net.SpatialBN( |
| [blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs, |
| order=self.order, **kwargs) |
| # Return the output |
| return blob_outputs[0] |
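
    # Example (a sketch; argument values are illustrative, and is_test=1
    # switches to the running statistics at inference time):
    #
    #   bn1 = model.SpatialBN("conv1", "conv1_spatbn", dim_in=64,
    #                         epsilon=1e-3, momentum=0.9, is_test=0)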
| |
| def Iter(self, blob_out, **kwargs): |
| if 'device_option' in kwargs: |
| del kwargs['device_option'] |
| self.param_init_net.ConstantFill( |
| [], blob_out, shape=[1], value=0, dtype=core.DataType.INT64, |
| device_option=core.DeviceOption(caffe2_pb2.CPU, 0), |
| **kwargs) |
| return self.net.Iter(blob_out, blob_out, **kwargs) |
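
    # Example (a sketch; the LearningRate arguments are illustrative, and the
    # negative base_lr follows the usual WeightedSum-based SGD convention):
    #
    #   ITER = model.Iter("iter")
    #   LR = model.net.LearningRate(ITER, "lr", base_lr=-0.1, policy="step",
    #                               stepsize=10000, gamma=0.1)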
| |
| def Accuracy(self, blob_in, blob_out, **kwargs): |
| dev = kwargs['device_option'] if 'device_option' in kwargs \ |
| else scope.CurrentDeviceScope() |
| is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU |
| |
| # We support top_k > 1 only on CPU |
| if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1: |
| pred_host = self.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host") |
| label_host = self.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host") |
| |
| # Now use the Host version of the accuracy op |
| self.net.Accuracy([pred_host, label_host], |
| blob_out, |
| device_option=core.DeviceOption(caffe2_pb2.CPU, 0), |
| **kwargs) |
| else: |
            self.net.Accuracy(blob_in, blob_out, **kwargs)
| |
| def PadImage( |
| self, blob_in, blob_out, **kwargs |
| ): |
        return self.net.PadImage(blob_in, blob_out, **kwargs)
| |
| @property |
| def XavierInit(self): |
| return ('XavierFill', {}) |
| |
| def ConstantInit(self, value): |
| return ('ConstantFill', dict(value=value)) |
| |
| @property |
| def MSRAInit(self): |
| return ('MSRAFill', {}) |
| |
| @property |
| def ZeroInit(self): |
| return ('ConstantFill', {}) |
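
    # Example (a sketch showing how these initializer shorthands are passed
    # to the layer helpers above; names and sizes are illustrative):
    #
    #   fc = model.FC("x", "fc", dim_in=128, dim_out=64,
    #                 weight_init=model.MSRAInit,
    #                 bias_init=model.ConstantInit(0.1))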
| |
| def AddWeightDecay(self, weight_decay): |
| """Adds a decay to weights in the model. |
| |
| This is a form of L2 regularization. |
| |
| Args: |
| weight_decay: strength of the regularization |
| """ |
| if weight_decay <= 0.0: |
| return |
| wd = self.param_init_net.ConstantFill([], 'wd', shape=[1], |
| value=weight_decay) |
| ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0) |
| for param in self.GetWeights(): |
| # Equivalent to: grad += wd * param |
| grad = self.param_to_grad[param] |
| self.net.WeightedSum( |
| [grad, ONE, param, wd], |
| grad, |
| ) |
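
    # Example (a sketch; assumes a loss blob and that AddGradientOperators
    # has been called so that param_to_grad is populated):
    #
    #   model.AddGradientOperators([loss])
    #   model.AddWeightDecay(1e-4)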
| |
| @property |
| def CPU(self): |
| device_option = caffe2_pb2.DeviceOption() |
| device_option.device_type = caffe2_pb2.CPU |
| return device_option |
| |
    @property
    def GPU(self):
        # As a property this cannot take a gpu_id argument, so it always
        # refers to GPU 0; build a DeviceOption manually (or use a device
        # scope) to target another GPU.
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = 0
        return device_option
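
    # Example (a sketch; core.DeviceScope is the usual way to apply one of
    # these device options to a group of operators):
    #
    #   with core.DeviceScope(model.GPU):
    #       conv1 = model.Conv("data", "conv1", dim_in=3, dim_out=64,
    #                          kernel=7)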
| |
| def LSTM(self, input_blob, seq_lengths, initial_states, dim_in, dim_out, |
| scope=None): |
| def s(name): |
| # We have to manually scope due to our internal/external blob |
| # relationships. |
| scope_name = scope or str(input_blob) |
| return "{}/{}".format(str(scope_name), str(name)) |
| |
| (hidden_input_blob, cell_input_blob) = initial_states |
| |
| input_blob = self.FC(input_blob, s("i2h"), |
| dim_in=dim_in, dim_out=4 * dim_out, axis=2) |
| |
| step_net = CNNModelHelper(name="LSTM") |
| step_net.Proto().external_input.extend([ |
| str(seq_lengths), |
| "input_t", |
| "timestep", |
| "hidden_t_prev", |
| "cell_t_prev", |
| s("gates_t_w"), |
| s("gates_t_b"), |
| ]) |
| step_net.Proto().type = "simple" |
| step_net.Proto().external_output.extend( |
| ["hidden_t", "cell_t", s("gates_t")]) |
| step_net.FC("hidden_t_prev", s("gates_t"), |
| dim_in=dim_out, dim_out=4 * dim_out, axis=2) |
| step_net.net.Sum([s("gates_t"), "input_t"], [s("gates_t")]) |
| step_net.net.LSTMUnit( |
| [ |
| "hidden_t_prev", |
| "cell_t_prev", |
| s("gates_t"), |
| str(seq_lengths), |
| "timestep", |
| ], |
| ["hidden_t", "cell_t"], |
| ) |
| |
| links = [ |
| ("hidden_t_prev", s("hidden"), 0), |
| ("hidden_t", s("hidden"), 1), |
| ("cell_t_prev", s("cell"), 0), |
| ("cell_t", s("cell"), 1), |
| ("input_t", str(input_blob), 0), |
| ] |
| link_internal, link_external, link_offset = zip(*links) |
| |
| # Set up the backward links |
| backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass( |
| step_net.Proto().op, |
| {"hidden_t": "hidden_t_grad", "cell_t": "cell_t_grad"}) |
| backward_mapping = {str(k): str(v) for k, v |
| in backward_mapping.items()} |
| backward_step_net = core.Net("LSTMBackward") |
| del backward_step_net.Proto().op[:] |
| backward_step_net.Proto().op.extend(backward_ops) |
| |
| backward_links = [ |
| ("hidden_t_prev_grad", s("hidden_grad"), 0), |
| ("hidden_t_grad", s("hidden_grad"), 1), |
| ("cell_t_prev_grad", s("cell_grad"), 0), |
| ("cell_t_grad", s("cell_grad"), 1), |
| (s("gates_t_grad"), str(input_blob) + "_grad", 0), |
| ] |
| |
| backward_link_internal, backward_link_external, \ |
| backward_link_offset = zip(*backward_links) |
| |
| backward_step_net.Proto().external_input.extend( |
| ["hidden_t_grad", "cell_t_grad"]) |
| backward_step_net.Proto().external_input.extend( |
| step_net.Proto().external_input) |
| backward_step_net.Proto().external_input.extend( |
| step_net.Proto().external_output) |
| |
        # Materialize as a list: under Python 3, map() returns an iterator,
        # but inputs.index() is used below to look up parameter positions.
        inputs = list(map(str, [input_blob, seq_lengths,
                                s("gates_t_w"), s("gates_t_b"),
                                hidden_input_blob, cell_input_blob]))
| recurrent_inputs = [str(hidden_input_blob), str(cell_input_blob)] |
| |
| output, _, _, hidden_state, cell_state, _ = self.net.RecurrentNetwork( |
| inputs, |
| [s("output"), s("hidden"), s("cell"), |
| s("hidden_output"), s("cell_output"), s("step_workspaces")], |
            param=list(map(inputs.index, step_net.params)),
| alias_src=[s("hidden"), s("hidden"), s("cell")], |
| alias_dst=[s("output"), s("hidden_output"), s("cell_output")], |
| alias_offset=[1, -1, -1], |
| recurrent_states=[s("hidden"), s("cell")], |
            initial_recurrent_state_ids=list(
                map(inputs.index, recurrent_inputs)),
| link_internal=link_internal, |
| link_external=link_external, |
| link_offset=link_offset, |
| backward_link_internal=backward_link_internal, |
| backward_link_external=backward_link_external, |
| backward_link_offset=backward_link_offset, |
| step_net=str(step_net.Proto()), |
| backward_step_net=str(backward_step_net.Proto()), |
| timestep="timestep", |
| outputs_with_grads=[0], |
| ) |
| self.param_init_net.Proto().op.extend( |
| step_net.param_init_net.Proto().op) |
| self.params += step_net.params |
| for p in step_net.params: |
| if str(p) in backward_mapping: |
| self.param_to_grad[p] = backward_mapping[str(p)] |
| self.weights += step_net.weights |
| self.biases += step_net.biases |
| return output, hidden_state, cell_state |
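
    # Example (a sketch; shapes, names and dimensions are illustrative, and
    # the initial hidden/cell blobs are assumed to be created and filled
    # elsewhere):
    #
    #   output, last_hidden, last_cell = model.LSTM(
    #       "embedded_seq",                  # shape (T, N, dim_in)
    #       "seq_lengths",                   # shape (N,)
    #       ("hidden_init", "cell_init"),
    #       dim_in=128, dim_out=256, scope="lstm1")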