blob: 551c712ed38bb991261465e1e8ac9c00afa0c663 [file] [log] [blame]
# https://pytorch.org/docs/stable/torch.html#math-operations
import math
import torch
class PointwiseOpsModule(torch.nn.Module):
def __init__(self):
super(PointwiseOpsModule, self).__init__()
def forward(self):
return self.pointwise_ops()
def pointwise_ops(self):
a = torch.randn(4)
b = torch.randn(4)
t = torch.tensor([-1, -2, 3], dtype=torch.int8)
r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
t = torch.tensor([-1, -2, 3], dtype=torch.int8)
s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
f = torch.zeros(3)
g = torch.tensor([-1, 0, 1])
w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
return len(
torch.abs(torch.tensor([-1, -2, 3])),
torch.absolute(torch.tensor([-1, -2, 3])),
torch.acos(a),
torch.arccos(a),
torch.acosh(a.uniform_(1.0, 2.0)),
torch.add(a, 20),
torch.add(a, b, out=a),
b.add(a),
b.add(a, out=b),
b.add_(a),
b.add(1),
torch.add(a, torch.randn(4, 1), alpha=10),
torch.addcdiv(
torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
),
torch.addcmul(
torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
),
torch.angle(a),
torch.asin(a),
torch.arcsin(a),
torch.asinh(a),
torch.arcsinh(a),
torch.atan(a),
torch.arctan(a),
torch.atanh(a.uniform_(-1.0, 1.0)),
torch.arctanh(a.uniform_(-1.0, 1.0)),
torch.atan2(a, a),
torch.bitwise_not(t),
torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
torch.ceil(a),
torch.ceil(float(torch.tensor(0.5))),
torch.ceil(torch.tensor(0.5).item()),
torch.clamp(a, min=-0.5, max=0.5),
torch.clamp(a, min=0.5),
torch.clamp(a, max=0.5),
torch.clip(a, min=-0.5, max=0.5),
torch.conj(a),
torch.copysign(a, 1),
torch.copysign(a, b),
torch.cos(a),
torch.cosh(a),
torch.deg2rad(
torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])
),
torch.div(a, b),
a.div(b),
a.div(1),
a.div_(b),
torch.divide(a, b, rounding_mode="trunc"),
torch.divide(a, b, rounding_mode="floor"),
torch.digamma(torch.tensor([1.0, 0.5])),
torch.erf(torch.tensor([0.0, -1.0, 10.0])),
torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
torch.exp(torch.tensor([0.0, math.log(2.0)])),
torch.exp(float(torch.tensor(1))),
torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
torch.expm1(torch.tensor([0.0, math.log(2.0)])),
torch.fake_quantize_per_channel_affine(
torch.randn(2, 2, 2),
(torch.randn(2) + 1) * 0.05,
torch.zeros(2),
1,
0,
255,
),
torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
torch.float_power(torch.randint(10, (4,)), 2),
torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
torch.floor(a),
torch.floor(float(torch.tensor(1))),
torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
torch.frac(torch.tensor([1.0, 2.5, -3.2])),
torch.randn(4, dtype=torch.cfloat).imag,
torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
torch.lerp(
torch.arange(1.0, 5.0),
torch.empty(4).fill_(10),
torch.full_like(torch.arange(1.0, 5.0), 0.5),
),
torch.lgamma(torch.arange(0.5, 2, 0.5)),
torch.log(torch.arange(5) + 10),
torch.log10(torch.rand(5)),
torch.log1p(torch.randn(5)),
torch.log2(torch.rand(5)),
torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
torch.logaddexp(
torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
),
torch.logaddexp(
torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
),
torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
torch.logaddexp2(
torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
),
torch.logaddexp2(
torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
),
torch.logical_and(r, s),
torch.logical_and(r.double(), s.double()),
torch.logical_and(r.double(), s),
torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
torch.logical_not(torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
torch.logical_not(
torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
out=torch.empty(3, dtype=torch.int16),
),
torch.logical_or(r, s),
torch.logical_or(r.double(), s.double()),
torch.logical_or(r.double(), s),
torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
torch.logical_xor(r, s),
torch.logical_xor(r.double(), s.double()),
torch.logical_xor(r.double(), s),
torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
torch.logit(torch.rand(5), eps=1e-6),
torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
torch.i0(torch.arange(5, dtype=torch.float32)),
torch.igamma(a, b),
torch.igammac(a, b),
torch.mul(torch.randn(3), 100),
b.mul(a),
b.mul(5),
b.mul(a, out=b),
b.mul_(a),
b.mul_(5),
torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
torch.nan_to_num(w),
torch.nan_to_num_(w),
torch.nan_to_num(w, nan=2.0),
torch.nan_to_num(w, nan=2.0, posinf=1.0),
torch.neg(torch.randn(5)),
# torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
torch.polygamma(1, torch.tensor([1.0, 0.5])),
torch.polygamma(2, torch.tensor([1.0, 0.5])),
torch.polygamma(3, torch.tensor([1.0, 0.5])),
torch.polygamma(4, torch.tensor([1.0, 0.5])),
torch.pow(a, 2),
torch.pow(2, float(torch.tensor(0.5))),
torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
torch.rad2deg(
torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])
),
torch.randn(4, dtype=torch.cfloat).real,
torch.reciprocal(a),
torch.remainder(torch.tensor([-3.0, -2.0]), 2),
torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
torch.round(a),
torch.round(torch.tensor(0.5).item()),
torch.rsqrt(a),
torch.sigmoid(a),
torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
torch.sgn(a),
torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
torch.sin(a),
torch.sinc(a),
torch.sinh(a),
torch.sqrt(a),
torch.square(a),
torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
b.sub(a),
b.sub_(a),
b.sub(5),
torch.sum(5),
torch.tan(a),
torch.tanh(a),
torch.true_divide(a, a),
torch.trunc(a),
torch.trunc_(a),
torch.xlogy(f, g),
torch.xlogy(f, g),
torch.xlogy(f, 4),
torch.xlogy(2, g),
)
class ReductionOpsModule(torch.nn.Module):
def __init__(self):
super(ReductionOpsModule, self).__init__()
def forward(self):
return self.reduction_ops()
def reduction_ops(self):
a = torch.randn(4)
b = torch.randn(4)
c = torch.tensor(0.5)
return len(
torch.argmax(a),
torch.argmin(a),
torch.amax(a),
torch.amin(a),
torch.aminmax(a),
torch.all(a),
torch.any(a),
torch.max(a),
a.max(a),
torch.max(a, 0),
torch.min(a),
a.min(a),
torch.min(a, 0),
torch.dist(a, b),
torch.logsumexp(a, 0),
torch.mean(a),
torch.mean(a, 0),
torch.nanmean(a),
torch.median(a),
torch.nanmedian(a),
torch.mode(a),
torch.norm(a),
a.norm(2),
torch.norm(a, dim=0),
torch.norm(c, torch.tensor(2)),
torch.nansum(a),
torch.prod(a),
torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
torch.quantile(a, 0.5),
torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
torch.std(a),
torch.std_mean(a),
torch.sum(a),
torch.unique(a),
torch.unique_consecutive(a),
torch.var(a),
torch.var_mean(a),
torch.count_nonzero(a),
)
class ComparisonOpsModule(torch.nn.Module):
def __init__(self):
super(ComparisonOpsModule, self).__init__()
def forward(self):
a = torch.tensor(0)
b = torch.tensor(1)
return len(
torch.allclose(a, b),
torch.argsort(a),
torch.eq(a, b),
torch.eq(a, 1),
torch.equal(a, b),
torch.ge(a, b),
torch.ge(a, 1),
torch.greater_equal(a, b),
torch.greater_equal(a, 1),
torch.gt(a, b),
torch.gt(a, 1),
torch.greater(a, b),
torch.isclose(a, b),
torch.isfinite(a),
torch.isin(a, b),
torch.isinf(a),
torch.isposinf(a),
torch.isneginf(a),
torch.isnan(a),
torch.isreal(a),
torch.kthvalue(a, 1),
torch.le(a, b),
torch.le(a, 1),
torch.less_equal(a, b),
torch.lt(a, b),
torch.lt(a, 1),
torch.less(a, b),
torch.maximum(a, b),
torch.minimum(a, b),
torch.fmax(a, b),
torch.fmin(a, b),
torch.ne(a, b),
torch.ne(a, 1),
torch.not_equal(a, b),
torch.sort(a),
torch.topk(a, 1),
torch.msort(a),
)
class OtherMathOpsModule(torch.nn.Module):
def __init__(self):
super(OtherMathOpsModule, self).__init__()
def forward(self):
return self.other_ops()
def other_ops(self):
a = torch.randn(4)
b = torch.randn(4)
c = torch.randint(0, 8, (5,), dtype=torch.int64)
e = torch.randn(4, 3)
f = torch.randn(4, 4, 4)
size = [0, 1]
dims = [0, 1]
return len(
torch.atleast_1d(a),
torch.atleast_2d(a),
torch.atleast_3d(a),
torch.bincount(c),
torch.block_diag(a),
torch.broadcast_tensors(a),
torch.broadcast_to(a, (4)),
# torch.broadcast_shapes(a),
torch.bucketize(a, b),
torch.cartesian_prod(a),
torch.cdist(e, e),
torch.clone(a),
torch.combinations(a),
torch.corrcoef(a),
# torch.cov(a),
torch.cross(e, e),
torch.cummax(a, 0),
torch.cummin(a, 0),
torch.cumprod(a, 0),
torch.cumsum(a, 0),
torch.diag(a),
torch.diag_embed(a),
torch.diagflat(a),
torch.diagonal(e),
torch.diff(a),
torch.einsum("iii", f),
torch.flatten(a),
torch.flip(e, dims),
torch.fliplr(e),
torch.flipud(e),
torch.kron(a, b),
torch.rot90(e),
torch.gcd(c, c),
torch.histc(a),
torch.histogram(a),
torch.meshgrid(a),
torch.meshgrid(a, indexing="xy"),
torch.lcm(c, c),
torch.logcumsumexp(a, 0),
torch.ravel(a),
torch.renorm(e, 1, 0, 5),
torch.repeat_interleave(c),
torch.roll(a, 1, 0),
torch.searchsorted(a, b),
torch.tensordot(e, e),
torch.trace(e),
torch.tril(e),
torch.tril_indices(3, 3),
torch.triu(e),
torch.triu_indices(3, 3),
torch.vander(a),
torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
torch.view_as_complex(torch.randn(4, 2)).real,
torch.resolve_conj(a),
torch.resolve_neg(a),
)
class SpectralOpsModule(torch.nn.Module):
def __init__(self):
super(SpectralOpsModule, self).__init__()
def forward(self):
return self.spectral_ops()
def spectral_ops(self):
a = torch.randn(10)
b = torch.randn(10, 8, 4, 2)
return len(
torch.stft(a, 8),
torch.stft(a, torch.tensor(8)),
torch.istft(b, 8),
torch.bartlett_window(2, dtype=torch.float),
torch.blackman_window(2, dtype=torch.float),
torch.hamming_window(4, dtype=torch.float),
torch.hann_window(4, dtype=torch.float),
torch.kaiser_window(4, dtype=torch.float),
)
class BlasLapackOpsModule(torch.nn.Module):
def __init__(self):
super(BlasLapackOpsModule, self).__init__()
def forward(self):
return self.blas_lapack_ops()
def blas_lapack_ops(self):
m = torch.randn(3, 3)
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 3)
v = torch.randn(3)
return len(
torch.addbmm(m, a, b),
torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
torch.addr(torch.zeros(3, 3), v, v),
torch.baddbmm(m, a, b),
torch.bmm(a, b),
torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
# torch.cholesky(a), # deprecated
# torch.cholesky_inverse(torch.randn(3, 3)), # had some error
# torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
torch.dot(v, v),
# torch.linalg.eig(m), # not build with lapack
# torch.geqrf(a),
torch.ger(v, v),
torch.inner(m, m),
# torch.inverse(m),
# torch.det(m),
# torch.logdet(m),
# torch.slogdet(m),
# torch.lstsq(m, m),
# torch.linalg.lu_factor(m),
# torch.lu_solve(m, *torch.linalg.lu_factor(m)),
# torch.lu_unpack(*torch.linalg.lu_factor(m)),
torch.matmul(m, m),
torch.matrix_power(m, 2),
# torch.matrix_rank(m),
torch.matrix_exp(m),
torch.mm(m, m),
torch.mv(m, v),
# torch.orgqr(a, m),
# torch.ormqr(a, m, v),
torch.outer(v, v),
# torch.pinverse(m),
# torch.qr(a),
# torch.solve(m, m),
# torch.svd(a),
# torch.svd_lowrank(a),
# torch.pca_lowrank(a),
# torch.symeig(a), # deprecated
# torch.lobpcg(a, b), # not supported
torch.trapz(m, m),
torch.trapezoid(m, m),
torch.cumulative_trapezoid(m, m),
# torch.triangular_solve(m, m),
torch.vdot(v, v),
)