Raise an error when pickling or copying FlagValues() (fixed version)
Deep copying flags (with copy.deepcopy) is still allowed. Pickling and shallow
are now prohibited, because the notion of a shallow copy isn't entirely well
defined -- should an unpickled flag or shallow copy link back to the original
flag value?
FlagValues() already cannot be successfully serialized/deserialized with
pickle. But the error message is unrelated and raised when attempting to
*load* pickled FlagValues instances, not when saving them:
>>> from absl import flags
>>> import pickle
>>> dumped = pickle.dumps(flags.FLAGS) # no error
>>> pickle.loads(dumped)
Traceback (most recent call last)
<ipython-input-5-5a8322d34219> in <module>()
----> 1 pickle.loads(dumped)
/usr/lib/python2.7/pickle.pyc in loads(str)
1386 def loads(str):
1387 file = StringIO(str)
-> 1388 return Unpickler(file).load()
1389
1390 # Doctest
/usr/lib/python2.7/pickle.pyc in load(self)
862 while 1:
863 key = read(1)
--> 864 dispatch[key](self)
865 except _Stop, stopinst:
866 return stopinst.value
/usr/lib/python2.7/pickle.pyc in load_build(self)
1219 state = stack.pop()
1220 inst = stack[-1]
-> 1221 setstate = getattr(inst, "__setstate__", None)
1222 if setstate:
1223 setstate(state)
/usr/local/lib/python2.7/dist-packages/absl/flags/_flagvalues.pyc in __getattr__(self, name)
466 def __getattr__(self, name):
467 """Retrieves the 'value' attribute of the flag --name."""
--> 468 fl = self._flags()
469 if name not in fl:
470 raise AttributeError(name)
/usr/local/lib/python2.7/dist-packages/absl/flags/_flagvalues.pyc in _flags(self)
139
140 def _flags(self):
--> 141 return self.__dict__['__flags']
142
143 def flags_by_module_dict(self):
KeyError: '__flags'
This change causes an error to be raised earlier (as part of serialization
rather than deserialization) and with a better error message, e.g.,
>>> pickle.dumps(flags.FLAGS)
TypeError: can't pickle FlagValues
PiperOrigin-RevId: 222298474
diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py
index 8ef6cde..856b47f 100644
--- a/absl/flags/_flag.py
+++ b/absl/flags/_flag.py
@@ -22,6 +22,7 @@
from __future__ import division
from __future__ import print_function
+import copy
import functools
from absl.flags import _argument_parser
@@ -125,6 +126,18 @@
return id(self) < id(other)
return NotImplemented
+ def __getstate__(self):
+ raise TypeError("can't pickle Flag objects")
+
+ def __copy__(self):
+ raise TypeError('%s does not support shallow copies. '
+ 'Use copy.deepcopy instead.' % type(self).__name__)
+
+ def __deepcopy__(self, memo):
+ result = object.__new__(type(self))
+ result.__dict__ = copy.deepcopy(self.__dict__, memo)
+ return result
+
def _get_parsed_value_as_string(self, value):
"""Returns parsed flag value as string."""
if value is None:
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 9ad9eca..d7f454b 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -22,6 +22,7 @@
from __future__ import division
from __future__ import print_function
+import copy
import itertools
import logging
import os
@@ -635,6 +636,18 @@
self._assert_all_validators()
return [program_name] + unparsed_args
+ def __getstate__(self):
+ raise TypeError("can't pickle FlagValues")
+
+ def __copy__(self):
+ raise TypeError('FlagValues does not support shallow copies. '
+ 'Use absl.testing.flagsaver or copy.deepcopy instead.')
+
+ def __deepcopy__(self, memo):
+ result = object.__new__(type(self))
+ result.__dict__.update(copy.deepcopy(self.__dict__, memo))
+ return result
+
def _set_is_retired_flag_func(self, is_retired_flag_func):
"""Sets a function for checking retired flags.
diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py
index f72ed6b..522711c 100644
--- a/absl/flags/tests/_flag_test.py
+++ b/absl/flags/tests/_flag_test.py
@@ -21,6 +21,9 @@
from __future__ import division
from __future__ import print_function
+import copy
+import pickle
+
from absl._enum_module import enum
from absl.flags import _argument_parser
from absl.flags import _exceptions
@@ -67,6 +70,24 @@
self.flag._set_default('orange')
self.assertEqual('apple', self.flag.value)
+ def test_pickle(self):
+ with self.assertRaisesRegexp(TypeError, "can't pickle Flag objects"):
+ pickle.dumps(self.flag)
+
+ def test_copy(self):
+ self.flag.value = 'orange'
+
+ with self.assertRaisesRegexp(
+ TypeError, 'Flag does not support shallow copies'):
+ copy.copy(self.flag)
+
+ flag2 = copy.deepcopy(self.flag)
+ self.assertEqual(flag2.value, 'orange')
+
+ flag2.value = 'mango'
+ self.assertEqual(flag2.value, 'mango')
+ self.assertEqual(self.flag.value, 'orange')
+
class BooleanFlagTest(parameterized.TestCase):
diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py
index e885979..72fb4d4 100644
--- a/absl/flags/tests/_flagvalues_test.py
+++ b/absl/flags/tests/_flagvalues_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import copy
+import pickle
import types
import unittest
@@ -341,6 +343,27 @@
self.assertEqual(3, len(fv))
self.assertTrue(fv)
+ def test_pickle(self):
+ fv = _flagvalues.FlagValues()
+ with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"):
+ pickle.dumps(fv)
+
+ def test_copy(self):
+ fv = _flagvalues.FlagValues()
+ _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
+ fv(['', '--answer=1'])
+
+ with self.assertRaisesRegexp(
+ TypeError, 'FlagValues does not support shallow copies'):
+ copy.copy(fv)
+
+ fv2 = copy.deepcopy(fv)
+ self.assertEqual(fv2.answer, 1)
+
+ fv2.answer = 42
+ self.assertEqual(fv2.answer, 42)
+ self.assertEqual(fv.answer, 1)
+
def test_conflicting_flags(self):
fv = _flagvalues.FlagValues()
with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError):