| import argparse |
| |
| from common import SubTensor, SubWithTorchFunction, WithTorchFunction # noqa: F401 |
| |
| import torch |
| |
| |
| Tensor = torch.tensor |
| |
| NUM_REPEATS = 1000000 |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Run the torch.add for a given class a given number of times." |
| ) |
| parser.add_argument( |
| "tensor_class", metavar="TensorClass", type=str, help="The class to benchmark." |
| ) |
| parser.add_argument( |
| "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats." |
| ) |
| args = parser.parse_args() |
| |
| TensorClass = globals()[args.tensor_class] |
| NUM_REPEATS = args.nreps |
| |
| t1 = TensorClass([1.0]) |
| t2 = TensorClass([2.0]) |
| |
| for _ in range(NUM_REPEATS): |
| torch.add(t1, t2) |