process types and processors
diff --git a/aten/src/aten/preprocess_declarations.py b/aten/src/aten/preprocess_declarations.py
index 55d2293..bc88a74 100644
--- a/aten/src/aten/preprocess_declarations.py
+++ b/aten/src/aten/preprocess_declarations.py
@@ -1,5 +1,98 @@
import common_with_cwrap
+cpu_floating_point = set([
+ 'float',
+ 'double',
+])
+
+cpu_integral = set([
+ 'byte',
+ 'char',
+ 'short',
+ 'int',
+ 'long'
+])
+
+cpu_types = cpu_floating_point | cpu_integral
+cpu_type_map = {
+ 'floating_point': cpu_floating_point,
+ 'integral': cpu_integral,
+ 'all': cpu_types
+}
+
+cuda_floating_point = cpu_floating_point | set(['half'])
+cuda_integral = cpu_integral
+cuda_types = cuda_floating_point | cuda_integral
+cuda_type_map = {
+ 'floating_point': cuda_floating_point,
+ 'integral': cuda_integral,
+ 'all': cuda_types
+}
+
+all_types = cpu_types | cuda_types
+
+processor_types = set([
+ 'cpu',
+ 'cuda',
+])
+
+processor_type_map = {
+ 'cpu': cpu_type_map,
+ 'cuda': cuda_type_map,
+}
+
+
+def process_types_and_processors(option):
+ # First, check if there are no types, processors, specifed. If so we assume
+ # that the method/function operates on all types and processors
+ if ('types' not in option and 'processors' not in option and
+ 'type_processor_pairs' not in option):
+ return option
+
+ # First, get the full set of types. If there are no types specified, but a
+ # processor is specified, assume we meant all types for that processor
+ processors = option['processors']
+ types = option['types'] if 'types' in option else []
+
+ if len(types) == 0:
+ assert(len(processors) == 1)
+ if processors[0] == 'cpu':
+ types = list(cpu_types)
+ elif processors[0] == 'cuda':
+ types = list(cuda_types)
+ else:
+ assert(False)
+
+ pairs = {}
+
+ # generate pairs for all processors
+ for processor in processors:
+ assert(processor in processor_types)
+ type_map = processor_type_map[processor]
+ for tstr in types:
+ # handle possible expansion
+ type_list = type_map[tstr] if tstr in type_map else [tstr]
+ for t in type_list:
+ assert(t in type_map['all'])
+ pairs[t] = processor
+
+ # if there are any prespecified tuples, handle them now
+ predefined_pairs = (option['type_processor_pairs'] if 'type_processor_pairs'
+ in option else {})
+ for tstr in predefined_pairs:
+ pr = predefined_pairs[tstr]
+ assert(pr in processor_types)
+ type_map = processor_type_map[processor]
+ # handle possible expansion
+ type_list = type_map[tstr] if tstr in type_map else [tstr]
+ for t in type_list:
+ assert(t in type_map['all'])
+ pairs[t] = pr
+
+ option['processor_type_pairs'] = pairs
+ return option
+
+
def exclude(declaration):
return 'only_register' in declaration
@@ -19,6 +112,10 @@
common_with_cwrap.set_declaration_defaults(declaration)
common_with_cwrap.enumerate_options_due_to_default(declaration)
common_with_cwrap.sort_by_number_of_options(declaration)
- add_variants(declaration)
+
+ declarations = [d for d in declarations if not exclude(d)]
+ declarations = [process_types_and_processors(d) for d in declarations]
+
+ add_variants(declaration)
return declarations