Add `@absltest.skipThisClass` to skip specific classes during testing.

This decorator marks a test in a way that it will be skipped, but none of its subclasses are. Suggested usage is for where you want to share functionality between tests, by having an 'abstract' base class:

```
@absltest.skipThisClass
class _BaseTestCase(absltest.TestCase):
  def test_foo(self):
    self.assertEqual(self.object_under_test.method()

class FooTest(_BaseTestCase):
  def setUp(self):
    self.object_under_test = Foo()

class BarTest(_BaseTestCase):
  def setUp(self):
    self.object_under_test = Bar()
```

There are alternatives, but they have drawbacks:

 * Having `_BaseTestCase` subclass object, and `FooTest` multiple-inherit from both `absltest.TestCase` and `_BaseTestCase`. However, this ends up being problematic for type checking.
 * Repeating the same logic in `absltest.skipThisClass` within `setUpClass` for every class to skip. However, that is repetitive logic that is best put into a utility function.

While `skipThisClass` is similar to `@unittest.skip`, it has an important distinction: regular `skip` will skip the decorated class and all subclasses;
`skipThisClass` only skips the decorated class, allowing base classes to be
correctly skipped while sub-classes are run as tests.

PiperOrigin-RevId: 367965048
Change-Id: Ie050c43c7f2e5dbc5af731171259c27084d1ba12
diff --git a/absl/CHANGELOG.md b/absl/CHANGELOG.md
index 7a1de2f..b5cfd28 100644
--- a/absl/CHANGELOG.md
+++ b/absl/CHANGELOG.md
@@ -9,6 +9,9 @@
 ### Added
 
 *   (app) Type annotations for public `app` interfaces.
+*   (testing) Added new decorator `@absltest.skipThisClass` to indicate a class
+    contains shared functionality to be used as a base class for other
+    TestCases, and therefore should be skipped.
 
 ### Changed
 
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index fbce512..3bf77d9 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -2142,6 +2142,81 @@
   return False
 
 
+def skipThisClass(reason):
+  # type: (Text) -> Callable[[_T], _T]
+  """Skip tests in the decorated TestCase, but not any of its subclasses.
+
+  This decorator indicates that this class should skip all its tests, but not
+  any of its subclasses. Useful for if you want to share testMethod or setUp
+  implementations between a number of concrete testcase classes.
+
+  Example usage, showing how you can share some common test methods between
+  subclasses. In this example, only 'BaseTest' will be marked as skipped, and
+  not RealTest or SecondRealTest:
+
+    @absltest.skipThisClass("Shared functionality")
+    class BaseTest(absltest.TestCase):
+      def test_simple_functionality(self):
+        self.assertEqual(self.system_under_test.method(), 1)
+
+    class RealTest(BaseTest):
+      def setUp(self):
+        super().setUp()
+        self.system_under_test = MakeSystem(argument)
+
+      def test_specific_behavior(self):
+        ...
+
+    class SecondRealTest(BaseTest):
+      def setUp(self):
+        super().setUp()
+        self.system_under_test = MakeSystem(other_arguments)
+
+      def test_other_behavior(self):
+        ...
+
+  Args:
+    reason: The reason we have a skip in place. For instance: 'shared test
+      methods' or 'shared assertion methods'.
+
+  Returns:
+    Decorator function that will cause a class to be skipped.
+  """
+  if isinstance(reason, type):
+    raise TypeError('Got {!r}, expected reason as string'.format(reason))
+
+  def _skip_class(test_case_class):
+    if not issubclass(test_case_class, unittest.TestCase):
+      raise TypeError(
+          'Decorating {!r}, expected TestCase subclass'.format(test_case_class))
+
+    # Only shadow the setUpClass method if it is directly defined. If it is
+    # in the parent class we invoke it via a super() call instead of holding
+    # a reference to it.
+    shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None)
+
+    @classmethod
+    def replacement_setupclass(cls, *args, **kwargs):
+      # Skip this class if it is the one that was decorated with @skipThisClass
+      if cls is test_case_class:
+        raise SkipTest(reason)
+      if shadowed_setupclass:
+        # Pass along `cls` so the MRO chain doesn't break.
+        # The original method is a `classmethod` descriptor, which can't
+        # be directly called, but `__func__` has the underlying function.
+        return shadowed_setupclass.__func__(cls, *args, **kwargs)
+      else:
+        # Because there's no setUpClass() defined directly on test_case_class,
+        # we call super() ourselves to continue execution of the inheritance
+        # chain.
+        return super(test_case_class, cls).setUpClass(*args, **kwargs)
+
+    test_case_class.setUpClass = replacement_setupclass
+    return test_case_class
+
+  return _skip_class
+
+
 class TestLoader(unittest.TestLoader):
   """A test loader which supports common test features.
 
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
index f7abd39..0ac3009 100644
--- a/absl/testing/tests/absltest_test.py
+++ b/absl/testing/tests/absltest_test.py
@@ -28,6 +28,7 @@
 import subprocess
 import sys
 import tempfile
+import unittest
 
 from absl.testing import _bazelize_command
 from absl.testing import absltest
@@ -2191,6 +2192,201 @@
     self.run_tempfile_helper('OFF', expected)
 
 
+class SkipClassTest(absltest.TestCase):
+
+  def test_incorrect_decorator_call(self):
+    with self.assertRaises(TypeError):
+
+      @absltest.skipThisClass  # pylint: disable=unused-variable
+      class Test(absltest.TestCase):
+        pass
+
+  def test_incorrect_decorator_subclass(self):
+    with self.assertRaises(TypeError):
+
+      @absltest.skipThisClass('reason')
+      def test_method():  # pylint: disable=unused-variable
+        pass
+
+  def test_correct_decorator_class(self):
+
+    @absltest.skipThisClass('reason')
+    class Test(absltest.TestCase):
+      pass
+
+    with self.assertRaises(absltest.SkipTest):
+      Test.setUpClass()
+
+  def test_correct_decorator_subclass(self):
+
+    @absltest.skipThisClass('reason')
+    class Test(absltest.TestCase):
+      pass
+
+    class Subclass(Test):
+      pass
+
+    with self.subTest('Base class should be skipped'):
+      with self.assertRaises(absltest.SkipTest):
+        Test.setUpClass()
+
+    with self.subTest('Subclass should not be skipped'):
+      Subclass.setUpClass()  # should not raise.
+
+  def test_setup(self):
+
+    @absltest.skipThisClass('reason')
+    class Test(absltest.TestCase):
+
+      @classmethod
+      def setUpClass(cls):
+        super(Test, cls).setUpClass()
+        cls.foo = 1
+
+    class Subclass(Test):
+      pass
+
+    Subclass.setUpClass()
+    self.assertEqual(Subclass.foo, 1)
+
+  def test_setup_chain(self):
+
+    @absltest.skipThisClass('reason')
+    class BaseTest(absltest.TestCase):
+
+      @classmethod
+      def setUpClass(cls):
+        super(BaseTest, cls).setUpClass()
+        cls.foo = 1
+
+    @absltest.skipThisClass('reason')
+    class SecondBaseTest(BaseTest):
+
+      @classmethod
+      def setUpClass(cls):
+        super(SecondBaseTest, cls).setUpClass()
+        cls.bar = 2
+
+    class Subclass(SecondBaseTest):
+      pass
+
+    Subclass.setUpClass()
+    self.assertEqual(Subclass.foo, 1)
+    self.assertEqual(Subclass.bar, 2)
+
+  def test_setup_args(self):
+
+    @absltest.skipThisClass('reason')
+    class Test(absltest.TestCase):
+
+      @classmethod
+      def setUpClass(cls, foo, bar=None):
+        super(Test, cls).setUpClass()
+        cls.foo = foo
+        cls.bar = bar
+
+    class Subclass(Test):
+
+      @classmethod
+      def setUpClass(cls):
+        super(Subclass, cls).setUpClass('foo', bar='baz')
+
+    Subclass.setUpClass()
+    self.assertEqual(Subclass.foo, 'foo')
+    self.assertEqual(Subclass.bar, 'baz')
+
+  def test_setup_multiple_inheritance(self):
+
+    # Test that skipping this class doesn't break the MRO chain and stop
+    # RequiredBase.setUpClass from running.
+    @absltest.skipThisClass('reason')
+    class Left(absltest.TestCase):
+      pass
+
+    class RequiredBase(absltest.TestCase):
+
+      @classmethod
+      def setUpClass(cls):
+        super(RequiredBase, cls).setUpClass()
+        cls.foo = 'foo'
+
+    class Right(RequiredBase):
+
+      @classmethod
+      def setUpClass(cls):
+        super(Right, cls).setUpClass()
+
+    # Test will fail unless Left.setUpClass() follows mro properly
+    # Right.setUpClass()
+    class Subclass(Left, Right):
+
+      @classmethod
+      def setUpClass(cls):
+        super(Subclass, cls).setUpClass()
+
+    class Test(Subclass):
+      pass
+
+    Test.setUpClass()
+    self.assertEqual(Test.foo, 'foo')
+
+  def test_skip_class(self):
+
+    @absltest.skipThisClass('reason')
+    class BaseTest(absltest.TestCase):
+
+      def test_foo(self):
+        _ = 1 / 0
+
+    class Test(BaseTest):
+
+      def test_foo(self):
+        self.assertEqual(1, 1)
+
+    with self.subTest('base class'):
+      ts = unittest.makeSuite(BaseTest)
+      self.assertEqual(1, ts.countTestCases())
+
+      res = unittest.TestResult()
+      ts.run(res)
+      self.assertTrue(res.wasSuccessful())
+      self.assertLen(res.skipped, 1)
+      self.assertEqual(0, res.testsRun)
+      self.assertEmpty(res.failures)
+      self.assertEmpty(res.errors)
+
+    with self.subTest('real test'):
+      ts = unittest.makeSuite(Test)
+      self.assertEqual(1, ts.countTestCases())
+
+      res = unittest.TestResult()
+      ts.run(res)
+      self.assertTrue(res.wasSuccessful())
+      self.assertEqual(1, res.testsRun)
+      self.assertEmpty(res.skipped)
+      self.assertEmpty(res.failures)
+      self.assertEmpty(res.errors)
+
+  def test_skip_class_unittest(self):
+
+    @absltest.skipThisClass('reason')
+    class Test(unittest.TestCase):  # note: unittest not absltest
+
+      def test_foo(self):
+        _ = 1 / 0
+
+    ts = unittest.makeSuite(Test)
+    self.assertEqual(1, ts.countTestCases())
+
+    res = unittest.TestResult()
+    ts.run(res)
+    self.assertTrue(res.wasSuccessful())
+    self.assertLen(res.skipped, 1)
+    self.assertEqual(0, res.testsRun)
+    self.assertEmpty(res.failures)
+    self.assertEmpty(res.errors)
+
+
 def _listdir_recursive(path):
   for dirname, _, filenames in os.walk(path):
     yield dirname