def forward(self, input):
    """Ignore ``input`` and return ``None``."""
    return None
| |
def eqBool(self, input):
    # type: (bool) -> bool
    """Return the bool argument unchanged."""
    return input
| |
def eqInt(self, input):
    # type: (int) -> int
    """Return the int argument unchanged."""
    return input
| |
def eqFloat(self, input):
    # type: (float) -> float
    """Return the float argument unchanged."""
    return input
| |
def eqStr(self, input):
    # type: (str) -> str
    """Return the str argument unchanged."""
    return input
| |
def eqTensor(self, input):
    # type: (Tensor) -> Tensor
    """Return the tensor argument unchanged (no copy is made here)."""
    return input
| |
def eqDictStrKeyIntValue(self, input):
    # type: (Dict[str, int]) -> Dict[str, int]
    """Return the str-keyed, int-valued dict argument unchanged."""
    return input
| |
def eqDictIntKeyIntValue(self, input):
    # type: (Dict[int, int]) -> Dict[int, int]
    """Return the int-keyed, int-valued dict argument unchanged."""
    return input
| |
def eqDictFloatKeyIntValue(self, input):
    # type: (Dict[float, int]) -> Dict[float, int]
    """Return the float-keyed, int-valued dict argument unchanged."""
    return input
| |
def listIntSumReturnTuple(self, input):
    # type: (List[int]) -> Tuple[List[int], int]
    """Return ``input`` paired with the sum of its elements.

    The first tuple element is the list that was passed in; the second is
    the running total, which stays 0 for an empty list.
    """
    total = 0
    for value in input:
        total = total + value
    return input, total
| |
def listBoolConjunction(self, input):
    # type: (List[bool]) -> bool
    """Return the logical AND of all elements; ``True`` for an empty list."""
    all_true = True
    for flag in input:
        # Same as ``all_true and flag``: once the accumulator is false it is
        # never overwritten again.
        if all_true:
            all_true = flag
    return all_true
| |
def listBoolDisjunction(self, input):
    # type: (List[bool]) -> bool
    """Return the logical OR of all elements; ``False`` for an empty list."""
    any_true = False
    for flag in input:
        # Same as ``any_true or flag``: once the accumulator is true it is
        # never overwritten again.
        if not any_true:
            any_true = flag
    return any_true
| |
def tupleIntSumReturnTuple(self, input):
    # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]
    """Return ``input`` paired with the sum of its elements.

    The first element of the result is the tuple that was passed in; the
    second is the total accumulated by iterating over it.
    """
    total = 0
    for value in input:
        total = total + value
    return input, total
| |
def optionalIntIsNone(self, input):
    # type: (Optional[int]) -> bool
    """Report whether ``input`` is ``None``."""
    is_missing = input is None
    return is_missing
| |
def intEq0None(self, input):
    # type: (int) -> Optional[int]
    """Return ``None`` when ``input`` equals 0, otherwise ``input`` itself."""
    if input != 0:
        return input
    return None
| |
def str3Concat(self, input):
    # type: (str) -> str
    """Return ``input`` concatenated with itself three times over."""
    doubled = input + input
    return doubled + input
| |
def newEmptyShapeWithItem(self, input):
    """Return element 0 of a new one-element tensor built from ``input``.

    ``input.item()`` is converted to ``int`` first, so any fractional part
    is dropped before the new tensor is created.
    """
    scalar = int(input.item())
    wrapped = torch.tensor([scalar])
    return wrapped[0]
| |
| def testAliasWithOffset(self): |
| # type: () -> List[Tensor] |
| x = torch.tensor([100, 200]) |
| a = [x[0], x[1]] |
| return a |
| |
| def testNonContiguous(self): |
| x = torch.tensor([100, 200, 300])[::2] |
| assert not x.is_contiguous() |
| assert x[0] == 100 |
| assert x[1] == 300 |
| return x |
| |
def conv2d(self, x, w, toChannelsLast):
    # type: (Tensor, Tensor, bool) -> Tensor
    """Apply ``torch.conv2d(x, w)`` and return a contiguous result.

    When ``toChannelsLast`` is true the result is made contiguous with
    ``memory_format=2``; otherwise the default ``contiguous()`` is used.

    NOTE(review): the integer 2 stands in for ``torch.channels_last``
    according to the inline comment; presumably the numeric form is what
    the consumer of this method requires -- confirm before changing it.
    """
    convolved = torch.conv2d(x, w)
    if toChannelsLast:
        # memory_format=torch.channels_last
        return convolved.contiguous(memory_format=2)
    return convolved.contiguous()
| |
def contiguous(self, x):
    # type: (Tensor) -> Tensor
    """Return the result of ``x.contiguous()``."""
    result = x.contiguous()
    return result
| |
def contiguousChannelsLast(self, x):
    # type: (Tensor) -> Tensor
    """Return ``x.contiguous(memory_format=2)``.

    NOTE(review): the integer 2 stands in for ``torch.channels_last``
    according to the inline comment; presumably the numeric form is what
    the consumer of this method requires -- confirm before changing it.
    """
    # memory_format=torch.channels_last
    result = x.contiguous(memory_format=2)
    return result
| |
def contiguousChannelsLast3d(self, x):
    # type: (Tensor) -> Tensor
    """Return ``x.contiguous(memory_format=3)``.

    NOTE(review): the integer 3 stands in for ``torch.channels_last_3d``
    according to the inline comment; presumably the numeric form is what
    the consumer of this method requires -- confirm before changing it.
    """
    # memory_format=torch.channels_last_3d
    result = x.contiguous(memory_format=3)
    return result