Allow passing in masks through a db (#31676)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31676
Facebook:
Previously we assumed the mask is passed in as a tensor, which is not feasible for sparse parameters.
Here we allow passing in the mask through a db path, which requires the masks to be stored in some db first.
Test Plan: unit tests
Reviewed By: ellie-wen
Differential Revision: D18928753
fbshipit-source-id: 75ca894de0f0dcd64ce17b13652484b3550cbdac
diff --git a/caffe2/python/optimizer.py b/caffe2/python/optimizer.py
index 333793e..cbf8eac 100644
--- a/caffe2/python/optimizer.py
+++ b/caffe2/python/optimizer.py
@@ -520,7 +520,7 @@
sparse_dedup_aggregator=None, rowWise=False, engine='',
lars=None, output_effective_lr=False,
output_effective_lr_and_update=False,
- mask_tensor=None, **kwargs):
+ pruning_options=None, **kwargs):
super(AdagradOptimizer, self).__init__()
self.alpha = alpha
self.epsilon = epsilon
@@ -532,12 +532,43 @@
self.lars = lars
self.output_effective_lr = output_effective_lr
self.output_effective_lr_and_update = output_effective_lr_and_update
- self.mask_tensor = mask_tensor
self.init_kwargs = kwargs
+ self._process_pruning_options(pruning_options)
+
+ def _process_pruning_options(self, pruning_options):
self.use_mask = False
+
+ if pruning_options is None:
+ pruning_options = {}
+ else:
+ assert isinstance(pruning_options, dict), "pruning_options can only "\
+ "be provided as a dictionary, currently: {}".format(pruning_options)
+
+ self.mask_tensor = pruning_options.get("mask_tensor", None)
+ self.mask_db_path = pruning_options.get("mask_db_path", None)
+ self.mask_db_type = pruning_options.get("mask_db_type", None)
+ self.mask_blob_name = pruning_options.get("mask_blob_name", None)
+
if self.mask_tensor is not None:
- assert type(mask_tensor) is np.ndarray, "mask_tensor must be a numpy array!"
+ assert type(self.mask_tensor) is np.ndarray, "mask_tensor must be a numpy array!"
+ assert self.mask_db_path is None, "mask can be provided through either a numpy array "\
+ "or a db path, not both"
+ assert self.mask_db_type is None, "mask can be provided through either a numpy array "\
+ "or a db path, not both"
+ assert self.mask_blob_name is None, "mask can be provided through either a numpy array "\
+ "or a db path, not both"
+ self.use_mask = True
+ if self.mask_db_path is not None or self.mask_db_type is not None\
+ or self.mask_blob_name is not None:
+ assert self.mask_db_path is not None, "when mask is provided through db, "\
+ "db path, db type, and blob name are all needed"
+ assert self.mask_db_type is not None, "when mask is provided through db, "\
+ "db path, db type, and blob name are all needed"
+ assert self.mask_blob_name is not None, "when mask is provided through db, "\
+ "db path, db type, and blob name are all needed"
+ assert self.mask_tensor is None, "mask can be provided through either a numpy array "\
+ "or a db path, not both"
self.use_mask = True
def _run(self, net, param_init_net, param_info):
@@ -622,13 +653,23 @@
)
if self.use_mask is True:
- if not isinstance(grad, core.GradientSlice):
- mask_blob = param_init_net.GivenTensorFill([], [str(param) + "_mask"], values=self.mask_tensor, shape=self.mask_tensor.shape)
+ if self.mask_tensor is not None:
+ if not isinstance(grad, core.GradientSlice):
+ mask_blob = param_init_net.GivenTensorFill([], [str(param) + "_mask"], values=self.mask_tensor, shape=self.mask_tensor.shape)
+ else:
+ self.mask_tensor = self.mask_tensor.astype(np.uint8)
+ mask_blob = param_init_net.GivenTensorBoolFill([], [str(param) + "_mask"], values=self.mask_tensor, shape=self.mask_tensor.shape)
+ mask_blob = param_init_net.Cast(mask_blob, to=core.DataType.UINT8)
+ mask_changed_blob = param_init_net.ConstantFill([], [str(param) + "_mask_changed_blob"], value=False, dtype=core.DataType.BOOL, shape=[1])
+ elif self.mask_db_path is not None or self.mask_db_type is not None\
+ or self.mask_blob_name is not None: # mask is provided through a db file
+ mask_blob = param_init_net.Load(
+ [], self.mask_blob_name, db=self.mask_db_path, db_type=self.mask_db_type, absolute_path=True
+ )
+ if isinstance(grad, core.GradientSlice):
+ mask_changed_blob = param_init_net.ConstantFill([], [str(param) + "_mask_changed_blob"], value=False, dtype=core.DataType.BOOL, shape=[1])
else:
- self.mask_tensor = self.mask_tensor.astype(np.uint8)
- mask_blob = param_init_net.GivenTensorBoolFill([], [str(param) + "_mask"], values=self.mask_tensor, shape=self.mask_tensor.shape)
- mask_blob = param_init_net.Cast(mask_blob, to=core.DataType.UINT8)
- mask_changed_blob = param_init_net.ConstantFill([], [str(param) + "_mask_changed_blob"], value=False, dtype=core.DataType.BOOL, shape=[1])
+ raise NotImplementedError("If mask is used, it needs to be provided through a numpy array or a db file")
self._aux_params.local.append(param_squared_sum)