from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import os
import sys
import unittest

import torch
import torch.autograd.function as function

# Make the top-level PyTorch test directory importable so that common_utils
# (the shared test helpers) can be resolved regardless of where tests run from.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(-1, pytorch_test_dir)

from common_utils import *  # noqa: F401

# All tests in this file assume float32 as the default tensor type.
torch.set_default_tensor_type('torch.FloatTensor')
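
# For illustration only (not part of the original file): with the default set
# above, freshly created floating-point tensors come out as torch.float32, e.g.
#
#     torch.tensor([1.0]).dtype  # -> torch.float32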


def _skipper(condition, reason):
    """Build a decorator that skips the wrapped test when ``condition()``
    holds, reporting ``reason``."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Evaluate the condition when the test runs, not at decoration
            # time, so the skip reflects the actual test environment.
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)
        return wrapper
    return decorator


skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(),
                        'CUDA is not available')

skipIfTravis = _skipper(lambda: os.getenv('TRAVIS'),
                        'Skip In Travis')
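
# Illustrative usage (hypothetical test case, not part of the original file):
# each decorator wraps a test method and raises unittest.SkipTest before the
# body runs whenever its condition holds.
#
#     class ExampleCudaTest(unittest.TestCase):
#         @skipIfNoCuda
#         def test_add_on_gpu(self):
#             x = torch.ones(2, device='cuda')
#             self.assertTrue(torch.equal(x + x, 2 * x))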


def flatten(x):
    # Recursively collect every torch.Tensor from a (possibly nested)
    # structure of lists and tuples, using the private
    # torch.autograd.function._iter_filter helper for the traversal.
    return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
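

if __name__ == '__main__':
    # A minimal sketch of flatten() in action (illustrative, not part of the
    # original file), assuming _iter_filter recurses through nested lists and
    # tuples as in the PyTorch version this file targets.
    nested = (torch.randn(2), [torch.randn(3), (torch.randn(1),)])
    flat = flatten(nested)
    print(type(flat), [t.shape for t in flat])  # a flat tuple of three tensors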