Fix handling of if_true/if_false in ATen
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index d4570ec..a1e6328 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -1,6 +1,12 @@
import re
from code_template import CodeTemplate
+import sys
+if sys.version_info[0] == 3:
+ string_type = str
+else:
+ string_type = basestring
+
# temporary things we cannot handle
EXCLUDE_PATTERN = "bernoulli.*|normal.*|exponential.*|random.*|arange.*"
# what has to be done to add a Operation ...
@@ -272,14 +278,21 @@
def requires_checked_cast(argument):
return argument['type'] in CHECKED_CAST
+ def bool_option_is_string(argument):
+ return 'if_true' in argument and isinstance(argument['if_true'], string_type)
+
def get_argument(argument, option):
if requires_checked_cast(argument):
return CHECKED_USE.get(argument['type'], '{}_').format(argument['name'])
elif argument['type'] == 'bool' and 'if_true' in argument:
- return '({}) ? "{}" : "{}"'.format(argument['name'],
- argument['if_true'], argument['if_false'])
+ if bool_option_is_string(argument):
+ tpl = '({}) ? "{}" : "{}"'
+ else:
+ tpl = '({}) ? {} : {}'
+ return tpl.format(argument['name'],
+ argument['if_true'], argument['if_false'])
elif argument['type'] == "CONSTANT":
- if 'if_true' in argument: # this was a bool that is actually a string...
+ if bool_option_is_string(argument): # this is a bool that is actually a string...
return '"{}"'.format(argument['name'])
v = str(argument['name'])
for pattern, replacement in CONSTANT_REPLACEMENTS: