blob: b4c8bbe3fb7c9c5cae7251d873a943e57ebd5f7a [file] [log] [blame]
import torch
from torch import _C
import sys
_sparse_tensor_classes = set()
class DoubleTensor(_C.SparseDoubleTensorBase):
def is_signed(self):
return True
class FloatTensor(_C.SparseFloatTensorBase):
def is_signed(self):
return True
class LongTensor(_C.SparseLongTensorBase):
def is_signed(self):
return True
class IntTensor(_C.SparseIntTensorBase):
def is_signed(self):
return True
class ShortTensor(_C.SparseShortTensorBase):
def is_signed(self):
return True
class CharTensor(_C.SparseCharTensorBase):
def is_signed(self):
# TODO
return False
class ByteTensor(_C.SparseByteTensorBase):
def is_signed(self):
return False
_sparse_tensor_classes.add(DoubleTensor)
_sparse_tensor_classes.add(FloatTensor)
_sparse_tensor_classes.add(LongTensor)
_sparse_tensor_classes.add(IntTensor)
_sparse_tensor_classes.add(ShortTensor)
_sparse_tensor_classes.add(CharTensor)
_sparse_tensor_classes.add(ByteTensor)
torch._tensor_classes.update(_sparse_tensor_classes)
_C._sparse_init()