# blob: 3b54967c5bacbb5137ad9bb8c3be27291c3d1099 [file] [log] [blame]
from typing import Dict, List, Tuple, Union, Any, Callable, Set
import torch
"""Useful functions to deal with tensor types with other python container types."""
def _apply_to_tensors(
    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Any:
    """Recursively apply ``fn`` to every tensor found inside ``container``.

    Args:
        fn: Callable invoked on each tensor; its return value takes the
            tensor's place in the result.
        container: A tensor, or an arbitrarily nested combination of dicts,
            lists, tuples (including named tuples) and sets.

    Returns:
        A newly built structure mirroring ``container``. Lists, tuples,
        named tuples and sets are rebuilt with their own type; any dict
        (including subclasses) is rebuilt as a plain ``dict`` with the same
        keys and transformed values. Objects that are neither tensors nor
        one of these containers are returned unchanged.
    """

    def apply(x: Any) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        if isinstance(x, dict):
            # Only values are visited; keys are carried over untouched.
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # Named tuple constructors take one positional argument per
            # field, so passing a single iterable (as done for plain tuples
            # below) would raise TypeError. Unpack the elements instead.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        return x

    return apply(container)