| from typing import Dict, List, Tuple, Union, Any, Callable, Set |
| |
| import torch |
| |
| |
| """Useful functions to deal with tensor types with other python container types.""" |
| |
| |
| def _apply_to_tensors( |
| fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set] |
| ) -> Any: |
| """Recursively apply to all tensor in different kinds of container types.""" |
| |
| def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any: |
| if torch.is_tensor(x): |
| return fn(x) |
| elif isinstance(x, dict): |
| return {key: apply(value) for key, value in x.items()} |
| elif isinstance(x, (list, tuple, set)): |
| return type(x)(apply(el) for el in x) |
| else: |
| return x |
| |
| return apply(container) |