Raise an error when pickling or copying FlagValues() (fixed version)

Deep copying flags (with copy.deepcopy) is still allowed. Pickling and shallow
are now prohibited, because the notion of a shallow copy isn't entirely well
defined -- should an unpickled flag or shallow copy link back to the original
flag value?

FlagValues() already cannot be successfully serialized/deserialized with
pickle. But the error message is unrelated and raised when attempting to
*load* pickled FlagValues instances, not when saving them:

  >>> from absl import flags
  >>> import pickle
  >>> dumped = pickle.dumps(flags.FLAGS)  # no error
  >>> pickle.loads(dumped)
  Traceback (most recent call last)
  <ipython-input-5-5a8322d34219> in <module>()
  ----> 1 pickle.loads(dumped)

  /usr/lib/python2.7/pickle.pyc in loads(str)
     1386 def loads(str):
     1387     file = StringIO(str)
  -> 1388     return Unpickler(file).load()
     1389
     1390 # Doctest

  /usr/lib/python2.7/pickle.pyc in load(self)
      862             while 1:
      863                 key = read(1)
  --> 864                 dispatch[key](self)
      865         except _Stop, stopinst:
      866             return stopinst.value

  /usr/lib/python2.7/pickle.pyc in load_build(self)
     1219         state = stack.pop()
     1220         inst = stack[-1]
  -> 1221         setstate = getattr(inst, "__setstate__", None)
     1222         if setstate:
     1223             setstate(state)

  /usr/local/lib/python2.7/dist-packages/absl/flags/_flagvalues.pyc in __getattr__(self, name)
      466   def __getattr__(self, name):
      467     """Retrieves the 'value' attribute of the flag --name."""
  --> 468     fl = self._flags()
      469     if name not in fl:
      470       raise AttributeError(name)

  /usr/local/lib/python2.7/dist-packages/absl/flags/_flagvalues.pyc in _flags(self)
      139
      140   def _flags(self):
  --> 141     return self.__dict__['__flags']
      142
      143   def flags_by_module_dict(self):

  KeyError: '__flags'

This change causes an error to be raised earlier (as part of serialization
rather than deserialization) and with a better error message, e.g.,

  >>> pickle.dumps(flags.FLAGS)
  TypeError: can't pickle FlagValues

PiperOrigin-RevId: 222298474
diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py
index 8ef6cde..856b47f 100644
--- a/absl/flags/_flag.py
+++ b/absl/flags/_flag.py
@@ -22,6 +22,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import functools
 
 from absl.flags import _argument_parser
@@ -125,6 +126,18 @@
       return id(self) < id(other)
     return NotImplemented
 
+  def __getstate__(self):
+    raise TypeError("can't pickle Flag objects")
+
+  def __copy__(self):
+    raise TypeError('%s does not support shallow copies. '
+                    'Use copy.deepcopy instead.' % type(self).__name__)
+
+  def __deepcopy__(self, memo):
+    result = object.__new__(type(self))
+    result.__dict__ = copy.deepcopy(self.__dict__, memo)
+    return result
+
   def _get_parsed_value_as_string(self, value):
     """Returns parsed flag value as string."""
     if value is None:
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 9ad9eca..d7f454b 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -22,6 +22,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import itertools
 import logging
 import os
@@ -635,6 +636,18 @@
     self._assert_all_validators()
     return [program_name] + unparsed_args
 
+  def __getstate__(self):
+    raise TypeError("can't pickle FlagValues")
+
+  def __copy__(self):
+    raise TypeError('FlagValues does not support shallow copies. '
+                    'Use absl.testing.flagsaver or copy.deepcopy instead.')
+
+  def __deepcopy__(self, memo):
+    result = object.__new__(type(self))
+    result.__dict__.update(copy.deepcopy(self.__dict__, memo))
+    return result
+
   def _set_is_retired_flag_func(self, is_retired_flag_func):
     """Sets a function for checking retired flags.
 
diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py
index f72ed6b..522711c 100644
--- a/absl/flags/tests/_flag_test.py
+++ b/absl/flags/tests/_flag_test.py
@@ -21,6 +21,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
+import pickle
+
 from absl._enum_module import enum
 from absl.flags import _argument_parser
 from absl.flags import _exceptions
@@ -67,6 +70,24 @@
     self.flag._set_default('orange')
     self.assertEqual('apple', self.flag.value)
 
+  def test_pickle(self):
+    with self.assertRaisesRegexp(TypeError, "can't pickle Flag objects"):
+      pickle.dumps(self.flag)
+
+  def test_copy(self):
+    self.flag.value = 'orange'
+
+    with self.assertRaisesRegexp(
+        TypeError, 'Flag does not support shallow copies'):
+      copy.copy(self.flag)
+
+    flag2 = copy.deepcopy(self.flag)
+    self.assertEqual(flag2.value, 'orange')
+
+    flag2.value = 'mango'
+    self.assertEqual(flag2.value, 'mango')
+    self.assertEqual(self.flag.value, 'orange')
+
 
 class BooleanFlagTest(parameterized.TestCase):
 
diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py
index e885979..72fb4d4 100644
--- a/absl/flags/tests/_flagvalues_test.py
+++ b/absl/flags/tests/_flagvalues_test.py
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
+import pickle
 import types
 import unittest
 
@@ -341,6 +343,27 @@
     self.assertEqual(3, len(fv))
     self.assertTrue(fv)
 
+  def test_pickle(self):
+    fv = _flagvalues.FlagValues()
+    with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"):
+      pickle.dumps(fv)
+
+  def test_copy(self):
+    fv = _flagvalues.FlagValues()
+    _defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
+    fv(['', '--answer=1'])
+
+    with self.assertRaisesRegexp(
+        TypeError, 'FlagValues does not support shallow copies'):
+      copy.copy(fv)
+
+    fv2 = copy.deepcopy(fv)
+    self.assertEqual(fv2.answer, 1)
+
+    fv2.answer = 42
+    self.assertEqual(fv2.answer, 42)
+    self.assertEqual(fv.answer, 1)
+
   def test_conflicting_flags(self):
     fv = _flagvalues.FlagValues()
     with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError):