# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package uniform_sampling
# Module caffe2.python.layers.uniform_sampling
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer


class UniformSampling(ModelLayer):
| """ |
| Uniform sampling `num_samples - len(input_record)` unique elements from the |
| range [0, num_elements). `samples` is the concatenation of input_record and |
| the samples. input_record is expected to be unique. |
| """ |

    def __init__(
        self,
        model,
        input_record,
        num_samples,
        num_elements,
        name='uniform_sampling',
        **kwargs
    ):
        super(UniformSampling, self).__init__(
            model, name, input_record, **kwargs
        )

        assert num_elements > num_samples > 0
        assert isinstance(input_record, schema.Scalar)

        self.num_elements = num_elements

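        # `num_samples` stored as a constant (non-trainable) int64 blob;
        # add_ops uses it to compute how many extra ids still need to be drawn.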
        num_examples_init = (
            'GivenTensorInt64Fill', {'values': [num_samples]}
        )
        self.num_samples = self.create_param(
            param_name='num_examples',
            shape=(1,),
            initializer=num_examples_init,
            optimizer=model.NoOptim,
        )

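        # Constant blob of length num_samples, each entry set to the uniform
        # sampling probability num_samples / num_elements.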
        sampling_blob_init = (
            'ConstantFill',
            {'value': float(num_samples) / num_elements,
             'dtype': core.DataType.FLOAT}
        )
        self.sampling_prob = self.create_param(
            param_name='prob',
            shape=(num_samples,),
            initializer=sampling_blob_init,
            optimizer=model.NoOptim,
        )

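        # Layer output: the concatenated ids plus their sampling probabilities
        # (the latter simply aliases the constant `prob` blob above).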
        self.output_schema = schema.Struct(
            (
                'samples', schema.Scalar(
                    np.int32, self.get_next_blob_reference("samples")
                )
            ),
            ('sampling_prob', schema.Scalar(np.float32, self.sampling_prob)),
        )

    def add_ops(self, net):
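        # The probabilities are constants; make sure no gradient flows into
        # them (StopGradient is applied in place).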
        net.StopGradient(self.sampling_prob, self.sampling_prob)

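        # How many extra ids to draw: num_samples - len(input_record).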
        shape = net.Shape([self.input_record()], net.NextScopedBlob("shape"))
        shape = net.Sub([self.num_samples, shape], shape)
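        # Draw that many unique ids uniformly from [0, num_elements),
        # excluding the ids already present in input_record.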
        samples = net.UniqueUniformFill(
            [shape, self.input_record()],
            net.NextScopedBlob("samples_before_concat"),
            min=0,
            max=self.num_elements - 1,
            input_as_shape=True
        )

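        # Final samples: input_record ids first, then the newly drawn ids.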
        net.Concat(
            [self.input_record(), samples],
            [self.output_schema.samples(), net.NextScopedBlob("split_info")],
            axis=0
        )
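        # The sampled ids are indices, not learnable values; stop gradients
        # here as well.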
        net.StopGradient(
            self.output_schema.samples(), self.output_schema.samples()
        )