Make cat match PyTorch syntax
diff --git a/Local.cwrap b/Local.cwrap
index 5f8d931..2792428 100644
--- a/Local.cwrap
+++ b/Local.cwrap
@@ -73,12 +73,12 @@
 [[
   name: cat
-  variants: [method]
+  cname: catArray
+  variants: [function]
   return: self
   arguments:
     - arg: THTensor* self
+      output: True
     - TensorList tensors
     - int dim
-  aten_custom_call: |
-    ${THTensor}_catArray(${state,}self_->tensor, tensors_.data(), tensors_.size(), dim);
 ]]
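The cwrap change above turns cat from a Tensor method with a hand-written aten_custom_call into a plain function declaration (cname: catArray) whose self argument is the output tensor, so the generated wrapper can call THTensor_(catArray) itself. For reference, this is the PyTorch-side call shape being matched; the snippet below is only an illustration, not part of the patch:

    import torch

    # cat as a free function over a sequence of tensors plus a dimension,
    # with an optional pre-allocated output (the role of the output: True arg).
    a, b = torch.ones(2, 3), torch.zeros(2, 3)
    joined = torch.cat([a, b], 0)     # cat(tensors, dim)
    out = torch.empty(4, 3)
    torch.cat([a, b], 0, out=out)     # writes into out instead of allocating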
diff --git a/Utils.h b/Utils.h
index cb3e1f6..d8e56d2 100644
--- a/Utils.h
+++ b/Utils.h
@@ -34,7 +34,7 @@
     }
     auto result = dynamic_cast<T*>(expr);
     if (result) {
-      casted.push_back(result->tensor);
+      casted[i] = result->tensor;
     } else {
       runtime_error("Expected a Tensor of type %s but found a type %s for sequence element %u "
                     " in sequence argument at position #%d '%s'",
diff --git a/function_wrapper.py b/function_wrapper.py
index 9b6add1..0ebac0d 100644
--- a/function_wrapper.py
+++ b/function_wrapper.py
@@ -109,6 +109,7 @@
     'THIntegerTensor*': '{}_->tensor',
     'THStorage*': '{}_->storage',
     'THGenerator*': '{}_->generator',
+    'TensorList': "{0}_.data(), {0}_.size()",
 }

 ALLOC_WRAP = {
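The new entry teaches this argument-expansion table (assumed here to be the map the generator consults when it formats the underlying TH call) how to pass a TensorList: one declared argument becomes the pointer-array/count pair that TH's array-taking functions expect, which is what the removed aten_custom_call used to spell out by hand. A quick check of the expansion, using only the format strings from the diff:

    expansions = {
        'THTensor*': '{}_->tensor',
        'TensorList': "{0}_.data(), {0}_.size()",
    }
    print(expansions['THTensor*'].format('self'))       # self_->tensor
    print(expansions['TensorList'].format('tensors'))   # tensors_.data(), tensors_.size()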
@@ -376,7 +377,8 @@
         # dim() == 0 of all input tensors is and'd to form
         # the test for whether the output is also a scalar
-        if not arg.get('output') and 'Tensor' in arg['type'] and not scalar_check_is_from_size:
+        if (not arg.get('output') and 'Tensor' in arg['type'] and
+                'TensorList' not in arg['type'] and not scalar_check_is_from_size):
             check = '{}.dim() == 0'.format(arg['name'])
             scalar_check = (check if scalar_check is None
                             else scalar_check + ' && ' + check)
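The extra clause is needed because the existing test is a substring check and 'Tensor' is a substring of 'TensorList', so a TensorList argument would otherwise get a '<name>.dim() == 0' scalar check emitted for it, which makes no sense for a list of tensors. A minimal check of that condition (scalar_check_is_from_size omitted):

    arg = {'name': 'tensors', 'type': 'TensorList'}
    old_hit = not arg.get('output') and 'Tensor' in arg['type']
    new_hit = old_hit and 'TensorList' not in arg['type']
    print(old_hit, new_hit)   # True False -> TensorList arguments are now skipped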