caffe2/caffe2/contrib/script (#15007)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15007
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14979
att
Reviewed By: dzhulgakov
Differential Revision: D13286191
fbshipit-source-id: b8a6bc7aea44487aea4dcf7f44c858fd30c6293c
diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt
index be8c0bd..6034e4d 100644
--- a/caffe2/contrib/CMakeLists.txt
+++ b/caffe2/contrib/CMakeLists.txt
@@ -4,7 +4,6 @@
add_subdirectory(opencl)
add_subdirectory(prof)
add_subdirectory(shm_mutex)
-add_subdirectory(script)
if (USE_TENSORRT)
add_subdirectory(tensorrt)
endif()
diff --git a/caffe2/contrib/script/CMakeLists.txt b/caffe2/contrib/script/CMakeLists.txt
deleted file mode 100644
index fb38787..0000000
--- a/caffe2/contrib/script/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-# ---[ CPU files.
-file(GLOB tmp *.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
-# exclude test files and gpu files
-file(GLOB tmp *_test.cc)
-exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
-exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${Caffe2_GPU_SRCS})
-
-# ---[ CPU test files
-file(GLOB tmp *_test.cc)
-set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
-exclude(Caffe2_CPU_TEST_SRCS "${Caffe2_CPU_TEST_SRCS}" ${Caffe2_GPU_TEST_SRCS})
-
-# ---[ Send the lists to the parent scope.
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
-set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
diff --git a/caffe2/contrib/script/caffe2_script_test.py b/caffe2/contrib/script/caffe2_script_test.py
deleted file mode 100644
index d9f0b65..0000000
--- a/caffe2/contrib/script/caffe2_script_test.py
+++ /dev/null
@@ -1,520 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-from hypothesis import given
-
-from caffe2.python import core, workspace
-from caffe2.proto import caffe2_pb2
-import caffe2.python.hypothesis_test_util as hu
-import hypothesis.strategies as st
-
-import numpy as np
-
-
-def feed_inputs(inputs):
- for name, value in inputs.items():
- workspace.FeedBlob(name, value)
-
-
-def assert_proto_equals(proto, expected):
- proto_lines = proto.strip().split('\n')
- expected_lines = expected.strip().split('\n')
- assert len(proto_lines) == len(expected_lines), \
- '{} != {}'.format(proto, expected)
- for left, right in zip(proto_lines, expected_lines):
- assert left.strip() == right.strip(), \
- '{} != {}'.format(proto, expected)
-
-
-class TestCaffe2Script(hu.HypothesisTestCase):
- test_program = """
- def foo(a,b,X,W) -> (c):
- t = a + b*b
- c = FC(X,W,t)
- def testIf(c0,c1,t,f) -> (r):
- if c0 < c1:
- r = t
- else:
- r = f
- r = Add(r,3f,broadcast=1)
- def testWhile(r) -> (r):
- m = 0
- while m < 4:
- # Plus operator automatically broadcasts, and we cannot
- # do in-place B and C arguments when we broadcast, so use
- # an explicit Add op.
- r = Add(r, r)
- m = m + 1
- """
-
- @given(firstdim=st.integers(min_value=1, max_value=4096),
- seconddim=st.integers(min_value=1, max_value=4096),
- seed=st.integers(min_value=0, max_value=65536),
- **hu.gcs)
- def test_foo(self, firstdim, seconddim, seed, gc, dc):
- np.random.seed(int(seed))
- inputs = {}
- a = inputs['a'] = np.random.rand(seconddim).astype(np.float32)
- b = inputs['b'] = np.random.rand(seconddim).astype(np.float32)
- X = inputs['X'] = np.random.rand(firstdim, firstdim).astype(np.float32)
- W = inputs['W'] = np.random.rand(seconddim, firstdim).astype(np.float32)
-
- feed_inputs(inputs)
-
- CU = core.C.CompilationUnit()
- CU.define(self.test_program)
- CU.create_net('foo').run()
-
- ref_t = a + b * b
- ref_c = np.matmul(X, W.transpose()) + ref_t
- actual_c = workspace.FetchBlob('c')
-
- np.testing.assert_allclose(actual_c, ref_c, rtol=1e-05)
-
- def test_trinary(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo(c) -> (d):
- d = 1 + (2 if c else 4)
- """)
- workspace.FeedBlob('c', np.ones((1), dtype=bool))
- net = CU.create_net('foo')
- net.run()
- assert(3 == workspace.FetchBlob('d'))
- workspace.FeedBlob('c', np.zeros((1), dtype=bool))
- net.run()
- assert(5 == workspace.FetchBlob('d'))
-
- def test_bool_literal(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a,b):
- a = True
- b = False
- """)
- net = CU.create_net('foo')
- net.run()
- assert(workspace.FetchBlob('a'))
- assert(not workspace.FetchBlob('b'))
-
- def test_bool_operators(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a, b, c, d, e):
- a = True and False
- b = True or False
- c = not b
- d = not False or True
- e = not (1 if a else 0) == (1 if b else 0)
- """)
- net = CU.create_net('foo')
- net.run()
- assert(not workspace.FetchBlob('a'))
- assert(workspace.FetchBlob('b'))
- assert(not workspace.FetchBlob('c'))
- assert(workspace.FetchBlob('d'))
- assert(workspace.FetchBlob('e'))
-
- def expect_fail(self, fn, msg):
- try:
- fn()
- except RuntimeError as r:
- if msg not in str(r):
- raise RuntimeError(
- "Failed wrong: expected string '{}' ".format(msg) +
- "in error message but found\n{}".format(str(r)))
-
- def test_fails(self):
- def fail_inputs():
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> ():
- Print(1,4)
- """)
- self.expect_fail(fail_inputs, "expects 1 inputs but found 2")
-
- def fail_undef():
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo(a) -> (b):
- a = what()
- """)
- self.expect_fail(fail_undef, "attempting to call unknown operation")
-
- def fail_schema():
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo(a) -> (b):
- a = FC(a,a,a)
- """)
- self.expect_fail(fail_schema, "failed schema checking")
-
- def test_print(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> ():
- a = 1
- Print(a)
- Print(a+1)
- _ = 4
- Print(_) # verify in print this isn't _ but some temorary
- Print(1)
- Print(1.f)
- Print(3.0)
- """)
- net = CU.create_net('foo')
- net.run()
-
- def test_method(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a):
- a = (3+1).Add(4).Add(1)
- """)
- net = CU.create_net('foo')
- net.run()
- assert(9 == workspace.FetchBlob('a'))
-
- def test_plus_eq(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a):
- a = 4
- a += 1
- """)
- net = CU.create_net('foo')
- net.run()
- assert(5 == workspace.FetchBlob('a'))
-
- def test_cast(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a):
- a = int(4.5f)
- """)
- net = CU.create_net('foo')
- net.run()
- assert(4 == workspace.FetchBlob('a'))
-
- def test_global(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def foo() -> (a):
- global m
- m.a = 4
- m.b = 5
- a = m.a + m.b
- """)
- net = CU.create_net('foo')
- net.run()
- assert(9 == workspace.FetchBlob('a'))
-
- def test_module_as_arg_ret(self):
- CU = core.C.CompilationUnit()
- CU.define("""
- def bar(a,c) -> (b):
- b = Module()
- temp = a.second
- b.first = temp
- b.second = a.first + c
- def foo() -> (a,b):
- x = Module()
- x.first = 1
- x.second = 2
- x.y = bar(x,4)
- a = x.y.first
- b = x.y.second
- """)
- net = CU.create_net('foo')
- net.run()
- assert(2 == workspace.FetchBlob('a'))
- assert(5 == workspace.FetchBlob('b'))
-
- def test_call_extern(self):
- CU = core.C.CompilationUnit()
- net = caffe2_pb2.NetDef()
- net.op.extend([
- core.CreateOperator(
- 'Mul',
- ['i', 'i'],
- ['o'],
- )
- ])
- net.external_input.append('i')
- net.external_output.append('o')
-
- CU.extern("myActualExtern", net)
- CU.define("""
- def myExtern(x) -> (y):
- t = x
- if t > 1:
- y = t * t
- else:
- y = 5
- def foo() -> (b):
- a = 4
- a += 1
- b = 2 + myExtern(a) + myExtern(a, rename=False) + myActualExtern(a)
- """)
- net = CU.create_net('foo')
- net.run()
- assert(77 == workspace.FetchBlob('b'))
-
- @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
- def test_if(self, seed, gc, dc):
- np.random.seed(int(seed))
- inputs = {}
- c0 = inputs['c0'] = np.random.rand(1).astype(np.float32)
- c1 = inputs['c1'] = np.random.rand(1).astype(np.float32)
- t = inputs['t'] = np.random.rand(3, 3).astype(np.float32)
- f = inputs['f'] = np.random.rand(3, 3).astype(np.float32)
-
- feed_inputs(inputs)
-
- CU = core.C.CompilationUnit()
- CU.define(self.test_program)
- CU.create_net('testIf').run()
-
- if c0 < c1:
- ref_r = t + 3
- else:
- ref_r = f + 3
- actual_r = workspace.FetchBlob('r')
-
- np.testing.assert_allclose(actual_r, ref_r)
-
- @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
- def test_while(self, seed, gc, dc):
- np.random.seed(int(seed))
- inputs = {}
- r = inputs['r'] = np.ones([3, 3]).astype(np.float32)
-
- feed_inputs(inputs)
-
- CU = core.C.CompilationUnit()
- CU.define(self.test_program)
- CU.create_net('testWhile').run()
-
- m = 0
- while m < 4:
- r = r + r
- m = m + 1
-
- actual_r = workspace.FetchBlob('r')
-
- np.testing.assert_allclose(actual_r, r)
-
- @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
- def test_gather(self, seed, gc, dc):
- CU = core.C.CompilationUnit()
- CU.define("""
- def easy(tensor, indices) -> (output):
- output = tensor[indices]
- def hard(tensor, i, j, k) -> (output):
- output = tensor[i][j][k]
- """)
-
- # First check that the generated proto is as expected. This tests that
- # we desugar the gather syntax correctly and emit the right code.
- proto = CU.get_proto('easy')
- assert_proto_equals(proto, """
- name: "easy"
- op {
- input: "tensor"
- input: "indices"
- output: "output"
- type: "Gather"
- }""")
-
- proto = CU.get_proto('hard')
- assert_proto_equals(proto, """
- name: "hard"
- op {
- input: "tensor"
- input: "i"
- output: "$t1"
- type: "Gather"
- }
- op {
- input: "$t1"
- input: "j"
- output: "$t0"
- type: "Gather"
- }
- op {
- input: "$t0"
- input: "k"
- output: "output"
- type: "Gather"
- }""")
-
- # Now just test that the effect of the generated code is as expected.
- np.random.seed(int(seed))
- tensor = np.random.rand(5, 4, 3).astype(np.float32)
- indices = np.random.randint(len(tensor), size=(5, 5))
-
- feed_inputs(dict(tensor=tensor, indices=indices))
-
- net = CU.create_net('easy')
- net.run()
-
- output = workspace.FetchBlob('output')
- expected_output = [tensor[sample] for sample in indices]
- np.testing.assert_allclose(output, expected_output)
-
- @given(seed=st.integers(min_value=0, max_value=65536), **hu.gcs)
- def test_slice(self, seed, gc, dc):
- CU = core.C.CompilationUnit()
- CU.define("""
- def slice_from_tensor(tensor, start, end) -> (output):
- output = tensor[start:end]
- def slice_from_vector(vector, start, end) -> (a, b, c, d):
- a = vector[start:end]
- b = vector[start:]
- c = vector[:end]
- d = vector[:]
- """)
-
- # slice_from_tensor
- proto = CU.get_proto('slice_from_tensor')
- assert_proto_equals(proto, """
- name: "slice_from_tensor"
- op {
- input: "tensor"
- input: "start"
- input: "end"
- output: "output"
- type: "Slice"
- }""")
-
- np.random.seed(int(seed))
- tensor = np.random.rand(5, 4, 3).astype(np.float32)
- start = np.array([0, 1, 0], dtype=np.int32)
- end = np.array([-1, 2, -1], dtype=np.int32)
-
- feed_inputs(dict(tensor=tensor, start=start, end=end))
-
- net = CU.create_net('slice_from_tensor')
- net.run()
-
- output = workspace.FetchBlob('output')
- np.testing.assert_allclose(output, tensor[:, 1:2])
-
- # slice_from_vector
- proto = CU.get_proto('slice_from_vector')
- assert_proto_equals(proto, """
- name: "slice_from_vector"
- op {
- input: "vector"
- input: "start"
- input: "end"
- output: "a"
- type: "Slice"
- }
- op {
- output: "$t0"
- type: "ConstantFill"
- arg {
- name: "dtype"
- i: 2
- }
- arg {
- name: "value"
- i: -1
- }
- arg {
- name: "shape"
- ints: 1
- }
- }
- op {
- input: "vector"
- input: "start"
- input: "$t0"
- output: "b"
- type: "Slice"
- }
- op {
- output: "$t1"
- type: "ConstantFill"
- arg {
- name: "dtype"
- i: 2
- }
- arg {
- name: "value"
- i: 0
- }
- arg {
- name: "shape"
- ints: 1
- }
- }
- op {
- input: "vector"
- input: "$t1"
- input: "end"
- output: "c"
- type: "Slice"
- }
- op {
- output: "$t2"
- type: "ConstantFill"
- arg {
- name: "dtype"
- i: 2
- }
- arg {
- name: "value"
- i: 0
- }
- arg {
- name: "shape"
- ints: 1
- }
- }
- op {
- output: "$t3"
- type: "ConstantFill"
- arg {
- name: "dtype"
- i: 2
- }
- arg {
- name: "value"
- i: -1
- }
- arg {
- name: "shape"
- ints: 1
- }
- }
- op {
- input: "vector"
- input: "$t2"
- input: "$t3"
- output: "d"
- type: "Slice"
- }""")
-
- vector = np.random.rand(10).astype(np.float32)
- start = np.array([2], dtype=np.int32)
- end = np.array([6], dtype=np.int32)
- feed_inputs(dict(vector=vector, start=start, end=end))
-
- net = CU.create_net('slice_from_vector')
- net.run()
-
- output = workspace.FetchBlob('a')
- np.testing.assert_allclose(output, vector[2:6])
-
- output = workspace.FetchBlob('b')
- np.testing.assert_allclose(output, vector[2:])
-
- output = workspace.FetchBlob('c')
- np.testing.assert_allclose(output, vector[:6])
-
- output = workspace.FetchBlob('d')
- np.testing.assert_allclose(output, vector)
diff --git a/caffe2/contrib/script/compiler.cc b/caffe2/contrib/script/compiler.cc
deleted file mode 100644
index 16a7657..0000000
--- a/caffe2/contrib/script/compiler.cc
+++ /dev/null
@@ -1,793 +0,0 @@
-#include "caffe2/core/net.h"
-#include "caffe2/utils/proto_utils.h"
-
-#include "compiler.h"
-#include "parser.h"
-
-namespace caffe2 {
-namespace script {
-
-namespace {
-
-static std::unordered_set<std::string> ops_containing_nets = {
- "If",
- "While",
- "RecurrentNetwork",
-};
-// record of defined function
-// NetDef + metadata
-struct FunctionDefinition {
- explicit FunctionDefinition(Def tree)
- : tree(new Def(tree)), net_def(new NetDef()) {}
-
- explicit FunctionDefinition(std::unique_ptr<NetDef> def)
- : tree(nullptr), net_def(std::move(def)) {
- // we coop extern_inputs/extern_outputs to be the inputs/outputs to
- // this net as a function
- // but we _dont_ set these when creating the net in the workspace
- // because they require the net to have valid inputs/outputs
- inputs.insert(
- inputs.begin(),
- net_def->external_input().begin(),
- net_def->external_input().end());
- outputs.insert(
- outputs.begin(),
- net_def->external_output().begin(),
- net_def->external_output().end());
- net_def->clear_external_output();
- net_def->clear_external_input();
- }
-
- bool isExtern() const {
- return tree == nullptr;
- }
- std::unique_ptr<Def> tree;
- std::unique_ptr<NetDef> net_def;
- std::vector<std::string> inputs;
- std::vector<std::string> outputs;
-};
-
-} // namespace
-
-using SymbolTable = std::unordered_map<std::string, FunctionDefinition>;
-
-struct DefCompiler {
- DefCompiler(FunctionDefinition& def, SymbolTable& symbol_table)
- : def(def),
- net_def_stack({def.net_def.get()}),
- symbol_table(symbol_table) {}
- void run() {
- auto& tree = *def.tree;
- cur().set_name(tree.name().name());
- for (auto input : tree.params()) {
- auto& name = input.ident().name();
- map(name, name);
- def.inputs.push_back(name);
- }
- for (auto output : tree.returns()) {
- auto& name = output.ident().name();
- map(name, name);
- def.outputs.push_back(name);
- }
- emitStatements(tree.statements());
- }
- void emitExpressionStatement(TreeRef stmt) {
- // expression with no used outputs
- emit(stmt, {});
- }
- void emitStatements(const ListView<TreeRef>& statements) {
- for (auto stmt : statements) {
- switch (stmt->kind()) {
- case TK_IF:
- emitIf(If(stmt));
- break;
- case TK_WHILE:
- emitWhile(While(stmt));
- break;
- case TK_ASSIGN:
- emitAssignment(Assign(stmt));
- break;
- case TK_GLOBAL:
- for (auto ident : stmt->trees()) {
- auto name = Ident(ident).name();
- map(name, name);
- }
- break;
- default:
- emitExpressionStatement(stmt);
- break;
- }
- }
- }
- void map(const std::string& name, const std::string& value) {
- env[name] = value;
- }
- const std::string& lookup(const Ident& ident) {
- if (env.count(ident.name()) == 0)
- throw ErrorReport(ident) << "undefined value " << ident.name();
- return env[ident.name()];
- }
- void emitAssignment(const Assign& stmt) {
- std::vector<std::string> outputs;
- for (auto lhs : stmt.lhs()) {
- std::string name = getLHS(lhs);
- // use of "_" gets renamed in Caffe2 graphs so that two uses
- // don't unintentionally interfere with each other
- if (name == "_") {
- name = fresh();
- }
- outputs.push_back(name);
- }
- if (stmt.reduction() != '=') {
- if (stmt.lhs().size() != 1) {
- throw ErrorReport(stmt)
- << "reductions are only allow when there is a single variable "
- << "on the left-hand side.";
- }
- auto lhs = stmt.lhs()[0];
- auto expr =
- Compound::create(stmt.reduction(), stmt.range(), {lhs, stmt.rhs()});
- emit(expr, outputs);
- } else {
- emit(stmt.rhs(), outputs);
- }
- int i = 0;
- for (auto ident : stmt.lhs()) {
- if (ident->kind() == TK_IDENT)
- map(Ident(ident).name(), outputs.at(i));
- i++;
- }
- }
- void emitIf(const If& stmt) {
- auto cond = getValue(stmt.cond());
- auto op = cur().add_op();
- op->set_type("If");
- op->add_input(cond);
- auto true_branch = op->add_arg();
- true_branch->set_name("then_net");
- auto nd = true_branch->mutable_n();
- net_def_stack.push_back(nd);
- emitStatements(stmt.trueBranch());
- net_def_stack.pop_back();
- if (stmt.falseBranch().size() > 0) {
- auto false_branch = op->add_arg();
- false_branch->set_name("else_net");
- auto nd = false_branch->mutable_n();
- net_def_stack.push_back(nd);
- emitStatements(stmt.falseBranch());
- net_def_stack.pop_back();
- }
- }
- void emitWhile(const While& stmt) {
- std::string loop_var = fresh();
- emitConst(0, loop_var, "i"); // it needs a definition before loop
- auto op = cur().add_op();
- op->set_type("While");
- auto cond = op->add_arg();
- cond->set_name("cond_net");
- auto cond_net = cond->mutable_n();
-
- net_def_stack.push_back(cond_net);
- emit(stmt.cond(), {loop_var});
- net_def_stack.pop_back();
-
- op->add_input(loop_var);
- auto body = op->add_arg();
- body->set_name("loop_net");
- auto body_net = body->mutable_n();
-
- net_def_stack.push_back(body_net);
- emitStatements(stmt.body());
- net_def_stack.pop_back();
- }
- std::string getLHS(const TreeRef& tree) {
- switch (tree->kind()) {
- case TK_IDENT: {
- return Ident(tree).name();
- } break;
- case '.': {
- auto sel = Select(tree);
- std::string lhs = getValue(sel.value());
- // TODO: check whether this subname exists in object lhs
- return lhs + "/" + sel.selector().name();
- } break;
- default: {
- throw ErrorReport(tree)
- << "This expression cannot appear on the left-hand size of an assignment";
- } break;
- }
- }
- std::string getValue(const TreeRef& tree) {
- switch (tree->kind()) {
- case TK_IDENT: {
- return lookup(Ident(tree));
- } break;
- case '.': {
- auto sel = Select(tree);
- std::string lhs = getValue(sel.value());
- // TODO: check whether this subname exists in object lhs
- return lhs + "/" + sel.selector().name();
- } break;
- default: {
- std::string name = fresh();
- emit(tree, {name});
- return name;
- } break;
- }
- }
- std::string fresh(std::string prefix = "$t") {
- return std::string(prefix) + c10::to_string(next_fresh++);
- }
- const char* operatorName(int kind, int ninputs) {
- switch (kind) {
- case '+':
- return "Add";
- case '-':
- if (ninputs == 1)
- return "Negative";
- else
- return "Sub";
- case '*':
- return "Mul";
- case '/':
- return "Div";
- case TK_NE:
- return "NE";
- case TK_EQ:
- return "EQ";
- case '<':
- return "LT";
- case '>':
- return "GT";
- case TK_LE:
- return "LE";
- case TK_GE:
- return "GE";
- case TK_IF_EXPR:
- return "Conditional";
- case TK_AND:
- return "And";
- case TK_OR:
- return "Or";
- case TK_NOT:
- return "Not";
- default:
- throw std::runtime_error("unknown kind " + c10::to_string(kind));
- }
- }
- void fillArg(Argument* arg, const Attribute& attr) {
- std::string name = attr.name().name();
- arg->set_name(name);
- auto value = attr.value();
- // TODO: handle non-float attributes
- switch (value->kind()) {
- case TK_CONST: {
- auto v = value->tree(0)->doubleValue();
- auto f = value->tree(1)->stringValue();
- if (f == "f")
- arg->set_f(v);
- else
- arg->set_i(v);
- } break;
- case TK_LIST:
- for (auto t : value->trees()) {
- auto v = t->tree(0)->doubleValue();
- auto f = t->tree(1)->stringValue();
- if (f == "f")
- arg->add_floats(v);
- else
- arg->add_ints(v);
- }
- break;
- }
- }
- template <typename Trees>
- std::vector<std::string> getValues(const Trees& trees) {
- std::vector<std::string> result;
- for (const auto& tree : trees) {
- result.push_back(getValue(tree));
- }
- return result;
- }
-
- bool renameLookup(
- std::unordered_map<std::string, std::string>& rename_map,
- const std::string& name,
- std::string& rename) {
- // first look for name in the map directly
- auto it = rename_map.find(name);
- if (it != rename_map.end()) {
- rename = it->second;
- return true;
- }
- // otherwise if we have a rename entry like a => b and a name "a/foo/bar"
- // then replace it with "b/foo/bar"
- auto p = name.find("/");
- if (p == std::string::npos)
- return false;
- it = rename_map.find(name.substr(0, p));
- if (it != rename_map.end()) {
- rename = it->second + name.substr(p);
- return true;
- }
- return false;
- }
- void renameOp(
- std::unordered_map<std::string, std::string>& rename_map,
- const Apply& apply,
- const std::string& prefix,
- bool isExtern,
- OperatorDef* new_op) {
- for (size_t i = 0; i < new_op->input().size(); i++) {
- auto& name = new_op->input(i);
- std::string renamed;
- bool defined = renameLookup(rename_map, name, renamed);
- if (!isExtern && !defined) {
- throw ErrorReport(apply)
- << " unexpected undefined name '" << name
- << "' while attempting to inline '" << apply.name().name() << "'";
- } else if (!defined) {
- // extern function using a global name, assign it an identity mapping
- rename_map[name] = name;
- }
- new_op->set_input(i, renamed);
- }
- for (size_t i = 0; i < new_op->output().size(); i++) {
- auto& name = new_op->output(i);
- std::string renamed;
- if (!renameLookup(rename_map, name, renamed)) {
- renamed = prefix + name;
- rename_map[name] = renamed;
- }
- new_op->set_output(i, renamed);
- }
- // handle control flow inside the op as well
- if (ops_containing_nets.count(new_op->type()) > 0) {
- for (size_t i = 0; i < new_op->arg_size(); i++) {
- auto* arg = new_op->mutable_arg(i);
- if (arg->has_n()) {
- auto* n = arg->mutable_n();
- for (size_t j = 0; j < n->op_size(); j++) {
- renameOp(rename_map, apply, prefix, isExtern, n->mutable_op(j));
- }
- }
- }
- }
- }
-
- bool hasBypassRename(const Apply& apply) {
- for (auto attr : apply.attributes()) {
- if (attr.name().name() == "rename") {
- if (attr.value()->kind() != TK_CONST) {
- throw ErrorReport(attr.value()) << "expected a single constant";
- }
- return attr.value()->tree(0)->doubleValue() == 0;
- }
- }
- return false;
- }
-
- // emit a function call by inlining the function's NetDef into our
- // net def, renaming temporaries func_name<unique_id>/orig_name
- // renaming only happens for values defined by the function
- // that are not marked outputs
-
- // inputs/outputs are passed by reference
- void emitFunctionCall(Apply& apply, const std::vector<std::string>& outputs) {
- std::string fname = apply.name().name();
- std::string prefix = fresh(fname) + "/";
- auto& fn = symbol_table.at(apply.name().name());
- bool isExtern = fn.isExtern();
- auto inputs = getValues(apply.inputs());
- std::unordered_map<std::string, std::string> rename_map;
- if (inputs.size() != fn.inputs.size()) {
- throw ErrorReport(apply) << fname << " expected " << fn.inputs.size()
- << " values but received " << inputs.size();
- }
- for (size_t i = 0; i < inputs.size(); i++) {
- rename_map[fn.inputs[i]] = inputs[i];
- }
- if (outputs.size() != fn.outputs.size()) {
- throw ErrorReport(apply) << fname << " expected " << fn.outputs.size()
- << " values but received " << outputs.size();
- }
- for (size_t i = 0; i < outputs.size(); i++) {
- rename_map[fn.outputs[i]] = outputs[i];
- }
- for (auto& op : fn.net_def->op()) {
- auto new_op = cur().add_op();
- new_op->CopyFrom(op);
- if (hasBypassRename(apply)) {
- prefix = "";
- }
- renameOp(rename_map, apply, prefix, isExtern, new_op);
- }
- }
- void expectOutputs(
- const TreeRef& tree,
- const std::vector<std::string>& outputs,
- size_t size) {
- if (outputs.size() != size) {
- throw ErrorReport(tree)
- << "expected operator to produce " << outputs.size()
- << " outputs but it produced " << size;
- }
- }
- void appendOutputs(
- const TreeRef& tree,
- OperatorDef* op,
- const std::vector<std::string>& outputs,
- size_t size) {
- expectOutputs(tree, outputs, size);
- for (size_t i = 0; i < size; i++) {
- op->add_output(outputs[i]);
- }
- }
- void emitOperator(
- const Apply& apply,
- const OpSchema* schema,
- const std::vector<std::string>& outputs) {
- // must be before add_op
- auto values = getValues(apply.inputs());
- if (values.size() < schema->min_input() ||
- values.size() > schema->max_input()) {
- if (schema->min_input() == schema->max_input()) {
- throw ErrorReport(apply) << "operator expects " << schema->min_input()
- << " inputs but found " << values.size();
- } else {
- throw ErrorReport(apply)
- << "operator takes between " << schema->min_input() << " and "
- << schema->max_input() << " inputs but found " << values.size()
- << ".";
- }
- }
- auto numActualOutputs = schema->CalculateOutput(values.size());
- if (numActualOutputs != kCannotComputeNumOutputs &&
- outputs.size() != numActualOutputs) {
- throw ErrorReport(apply)
- << "operator produces " << numActualOutputs
- << " outputs but matched to " << outputs.size() << " outputs";
- }
- auto op = cur().add_op();
- op->set_type(apply.name().name());
- for (auto& v : values) {
- op->add_input(v);
- }
- // assume 1 output unless matched to more
- appendOutputs(apply, op, outputs, outputs.size());
- for (auto attribute : apply.attributes()) {
- fillArg(op->add_arg(), attribute);
- }
- // Ok, we checked the stuff where we can easily give a friendly error
- // message, now verify against the schema and report the error at the line
- if (!schema->Verify(*op)) {
- throw ErrorReport(apply) << "failed schema checking";
- }
- }
-
- // Emit an operation, writing results into 'outputs'.
- // This will _always_ compute something, unlike 'getValue' which simply
- // returns an already computed reference if possible.
- // So if 'tree' is an identifier or nested identifier (foo.bar)
- // this will cause it to be _copied_ into outputs.
- void emit(const TreeRef& tree, const std::vector<std::string>& outputs) {
- switch (tree->kind()) {
- case TK_IDENT:
- case '.': {
- auto op = cur().add_op();
- op->set_type("Copy");
- op->add_input(getValue(tree));
- appendOutputs(tree, op, outputs, 1);
- } break;
- case TK_NE:
- case TK_EQ:
- case '<':
- case '>':
- case TK_LE:
- case TK_GE:
- case '-':
- case '*':
- case '/':
- case '+':
- case TK_AND:
- case TK_OR:
- case TK_NOT:
- case TK_IF_EXPR: {
- // must be before add_op
- auto values = getValues(tree->trees());
- auto op = cur().add_op();
- op->set_type(operatorName(tree->kind(), tree->trees().size()));
- for (auto& v : values) {
- op->add_input(v);
- }
- appendOutputs(tree, op, outputs, 1);
- auto broadcast = op->add_arg();
- broadcast->set_name("broadcast");
- broadcast->set_i(1);
- } break;
- case TK_APPLY: {
- auto apply = Apply(tree);
- // Handle built-ins like zeros, ones, etc
- if (builtins.count(apply.name().name()) > 0) {
- builtins[apply.name().name()](this, apply, outputs);
- break;
- }
- if (symbol_table.count(apply.name().name()) > 0) {
- emitFunctionCall(apply, outputs);
- break;
- }
- auto schema = OpSchemaRegistry::Schema(apply.name().name());
- if (schema) {
- emitOperator(apply, schema, outputs);
- break;
- }
- throw ErrorReport(apply)
- << "attempting to call unknown operation or function '"
- << apply.name().name() << "'";
- } break;
- case TK_CAST: {
- auto cast = Cast(tree);
- auto c2type = getType(cast.type());
- auto input = getValue(cast.input());
- auto op = cur().add_op();
- op->set_type("Cast");
- op->add_input(input);
- appendOutputs(tree, op, outputs, 1);
- auto arg = op->add_arg();
- arg->set_name("to");
- arg->set_i(c2type);
- } break;
- case TK_CONST: {
- expectOutputs(tree, outputs, 1);
- emitConst(
- tree->tree(0)->doubleValue(),
- outputs[0],
- tree->tree(1)->stringValue());
- } break;
- case TK_GATHER: {
- const auto gather = Gather(tree);
- desugarAndEmitOperator(
- "Gather",
- gather.range(),
- {gather.value(), gather.indices()},
- outputs);
- break;
- }
- case TK_SLICE: {
- const auto slice = Slice(tree);
- desugarAndEmitOperator(
- "Slice",
- slice.range(),
- {slice.value(), slice.startOr(0), slice.endOr(-1)},
- outputs);
- break;
- }
- default:
- throw ErrorReport(tree) << "NYI: " << tree;
- break;
- }
- }
-
- // Desugars constructs that are syntactic sugar and emits the corresponding
- // operator invocation, e.g. tensor[indices] -> tensor.Gather(indices).
- void desugarAndEmitOperator(
- const std::string& operatorName,
- const SourceRange& range,
- TreeList&& inputs,
- const std::vector<std::string>& outputs) {
- const auto applyName = Ident::create(range, operatorName);
- const auto applyInputs =
- Compound::create(TK_LIST, range, std::move(inputs));
- const auto applyAttributes = Compound::create(TK_LIST, range, {});
- const auto apply =
- Apply::create(range, applyName, applyInputs, applyAttributes);
- const auto schema = OpSchemaRegistry::Schema(operatorName);
- assert(schema != nullptr);
- emitOperator(Apply(apply), schema, outputs);
- }
-
- TensorProto_DataType getType(int type) {
- switch (type) {
- case TK_INT:
- return TensorProto_DataType_INT32;
- case TK_FLOAT:
- return TensorProto_DataType_FLOAT;
- case TK_LONG:
- return TensorProto_DataType_INT64;
- case TK_BOOL:
- return TensorProto_DataType_BOOL;
- default:
- throw std::runtime_error(
- "expected type token: " + c10::to_string(type));
- }
- }
-
- OperatorDef* emitConst(
- double v,
- const std::string& output,
- const std::string& type_ident) {
- auto op = cur().add_op();
- op->set_type("ConstantFill");
- auto dtype = op->add_arg();
- dtype->set_name("dtype");
- auto value = op->add_arg();
- value->set_name("value");
- if (type_ident == "f") {
- dtype->set_i(TensorProto_DataType_FLOAT);
- value->set_f(v);
- } else if (type_ident == "LL") {
- dtype->set_i(TensorProto_DataType_INT64);
- value->set_i(v);
- } else if (type_ident == "b") {
- dtype->set_i(TensorProto_DataType_BOOL);
- value->set_i(v != 0);
- } else if (type_ident == "i") {
- dtype->set_i(TensorProto_DataType_INT32);
- value->set_i(v);
- } else {
- throw std::runtime_error("unknown type_ident " + type_ident);
- }
- auto shape = op->add_arg();
- shape->set_name("shape");
- shape->add_ints(1);
- op->add_output(output);
- return op;
- }
- NetDef& cur() {
- return *net_def_stack.back();
- }
- FunctionDefinition& def; // the def being constructed
- std::unordered_map<std::string, std::string>
- env; // map from name in Def to name in NetDef
- std::vector<NetDef*> net_def_stack;
- SymbolTable& symbol_table;
- int next_fresh = 0;
-
- private:
- void emitFillOp(const Apply& apply, const std::vector<std::string>& outputs) {
- auto builtin_type = apply.name().name();
- auto values = getValues(apply.inputs());
- if (values.size() > 1) {
- throw ErrorReport(apply)
- << "Built-in " << builtin_type << " accepts 0 or 1 inputs.";
- }
- bool has_shape = false;
- for (const auto& attribute : apply.attributes()) {
- if (attribute.name().name() == "shape") {
- has_shape = true;
- } else {
- throw ErrorReport(apply)
- << "Unrecognized attribute " << attribute.name().name()
- << " for built-in " << builtin_type;
- }
- }
- if (builtin_type == "zeros" || builtin_type == "ones") {
- if ((values.size() != 1) && !has_shape) {
- throw ErrorReport(apply)
- << "Built-in " << builtin_type
- << " requires either 1 input or 1 shape attribute";
- }
- } else {
- // zeros_like or ones_like
- if (values.size() != 1) {
- throw ErrorReport(apply)
- << "Built-in " << builtin_type << " requires 1 input";
- }
- }
-
- auto op = cur().add_op();
- op->set_type("ConstantFill");
- if (values.size()) {
- op->add_input(values[0]);
- auto* input_as_shape = op->add_arg();
- input_as_shape->set_name("input_as_shape");
- if (builtin_type.find("_like") != std::string::npos) {
- // zeros_like, ones_like take the shape of the input as constant
- // tensor shape
- input_as_shape->set_i(0);
- } else {
- // zeros, ones take the values in the tensor as constant tensor
- // shape
- input_as_shape->set_i(1);
- }
- } else {
- fillArg(op->add_arg(), apply.attributes()[0]);
- }
-
- auto value = op->add_arg();
- value->set_name("value");
- if (builtin_type.find("ones") != std::string::npos) {
- value->set_f(1.0f);
- } else {
- value->set_f(0.0f);
- }
- appendOutputs(apply, op, outputs, 1);
- }
- // emitModule doesn't actually do anything except for allow
- // statements like a = Module() to register 'a' as a valid identifier
- // so that a.b = ... will work
- void emitModule(const Apply& apply, const std::vector<std::string>& outputs) {
- expectOutputs(apply, outputs, 1);
- }
- std::unordered_map<
- std::string,
- std::function<void(
- DefCompiler*,
- const Apply&,
- const std::vector<std::string>& outputs)>>
- builtins{{"zeros", &DefCompiler::emitFillOp},
- {"zeros_like", &DefCompiler::emitFillOp},
- {"ones", &DefCompiler::emitFillOp},
- {"ones_like", &DefCompiler::emitFillOp},
- {"Module", &DefCompiler::emitModule}};
-};
-
-struct CompilationUnitImpl {
- void defineFunction(const Def& def) {
- if (functions.count(def.name().name()) > 0) {
- throw ErrorReport(def) << def.name().name() << " already defined.";
- }
- DefCompiler c(
- functions.emplace(def.name().name(), FunctionDefinition(def))
- .first->second,
- functions);
- c.run();
- }
-
- void define(const std::string& str) {
- Parser p(str);
- while (p.lexer().cur().kind != TK_EOF) {
- defineFunction(Def(p.parseFunction()));
- }
- }
-
- std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& str) {
- if (functions.count(str) == 0)
- throw ErrorReport() << "undefined function: " << str << "\n";
- auto& def = functions.at(str);
- return caffe2::CreateNet(*def.net_def, ws);
- }
-
- void defineExtern(const std::string& name, std::unique_ptr<NetDef> net_def) {
- // TODO: unify extern and function namespaces
- if (functions.count(name) > 0) {
- throw ErrorReport() << "function '" << name << "' already defined.";
- }
- functions.emplace(name, FunctionDefinition(std::move(net_def)));
- }
-
- std::string getProto(const std::string& functionName) {
- return functions.at(functionName).net_def->DebugString();
- }
-
- private:
- friend struct DefCompiler;
- SymbolTable functions;
-};
-
-CompilationUnit::CompilationUnit() : pImpl(new CompilationUnitImpl()) {}
-
-void CompilationUnit::define(const std::string& str) {
- return pImpl->define(str);
-}
-
-void CompilationUnit::defineExtern(
- const std::string& name,
- std::unique_ptr<NetDef> nd) {
- pImpl->defineExtern(name, std::move(nd));
-}
-
-std::unique_ptr<NetBase> CompilationUnit::createNet(
- Workspace* ws,
- const std::string& str) {
- return pImpl->createNet(ws, str);
-}
-
-std::string CompilationUnit::getProto(const std::string& functionName) const {
- return pImpl->getProto(functionName);
-}
-
-CompilationUnit::~CompilationUnit() {}
-
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/compiler.h b/caffe2/contrib/script/compiler.h
deleted file mode 100644
index 0a15c33..0000000
--- a/caffe2/contrib/script/compiler.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-#include <memory>
-#include <string>
-#include "caffe2/core/net.h"
-
-namespace caffe2 {
-namespace script {
-
-struct CompilationUnitImpl;
-
-struct CAFFE2_API CompilationUnit {
- CompilationUnit();
- void define(const std::string& str);
- void defineExtern(const std::string& str, std::unique_ptr<NetDef> netdef);
- std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& name);
- std::string getProto(const std::string& functionName) const;
- ~CompilationUnit();
-
- private:
- std::unique_ptr<CompilationUnitImpl> pImpl;
-};
-
-} // namespace script
-}; // namespace caffe2
diff --git a/caffe2/contrib/script/error_report.h b/caffe2/contrib/script/error_report.h
deleted file mode 100644
index cecc0f3..0000000
--- a/caffe2/contrib/script/error_report.h
+++ /dev/null
@@ -1,51 +0,0 @@
-#pragma once
-
-#include "caffe2/contrib/script/tree.h"
-
-namespace caffe2 {
-namespace script {
-
-struct ErrorReport : public std::exception {
- ErrorReport(const ErrorReport& e)
- : ss(e.ss.str()), context(e.context), the_message(e.the_message) {}
-
- ErrorReport() : context(nullptr) {}
- explicit ErrorReport(const SourceRange& r)
- : context(std::make_shared<SourceRange>(r)) {}
- explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
- explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
- virtual const char* what() const noexcept override {
- std::stringstream msg;
- msg << "\n" << ss.str();
- if (context != nullptr) {
- msg << ":\n";
- context->highlight(msg);
- } else {
- msg << ".\n";
- }
- the_message = msg.str();
- return the_message.c_str();
- }
-
- private:
- template <typename T>
- friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
-
- mutable std::stringstream ss;
- std::shared_ptr<SourceRange> context;
- mutable std::string the_message;
-};
-
-template <typename T>
-const ErrorReport& operator<<(const ErrorReport& e, const T& t) {
- e.ss << t;
- return e;
-}
-
-#define C2S_ASSERT(ctx, cond) \
- if (!(cond)) { \
- throw ::caffe2::script::ErrorReport(ctx) \
- << __FILE__ << ":" << __LINE__ << ": assertion failed: " << #cond; \
- }
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/examples/example_beam_search.c2s b/caffe2/contrib/script/examples/example_beam_search.c2s
deleted file mode 100644
index 2e081ee..0000000
--- a/caffe2/contrib/script/examples/example_beam_search.c2s
+++ /dev/null
@@ -1,76 +0,0 @@
-[["log_probs", [6, 1, 44463], "float32"], ["attentions", [6, 1, 21], "float32"], ["inputs", [21], "float32"]]
-beam_search
-["scores_t"]
-
-def beam_search(inputs, log_probs, attentions) -> ():
- beam_size = 6LL
- length = 20LL
- beam_output_shape, _ = Concat(length + 1LL, beam_size, axis=0)
- output_token_beam_list = int(zeros(beam_output_shape))
- output_prev_index_beam_list = int(zeros(beam_output_shape))
- output_score_beam_list = zeros(beam_output_shape)
-
- input_length = inputs.Size().ExpandDims(dims=[0])
-
- attention_beam_output_shape, _ = Concat(
- input_length, beam_output_shape, axis=0)
- output_attention_weights_beam_list = zeros(attention_beam_output_shape)
-
- attention_step_output_shape, _ = Concat(beam_size, input_length, axis=0)
- attention_t = zeros(attention_step_output_shape)
-
- scores_t = zeros(shape=[1, 6])
- hypo_t = int(zeros(shape=[6]))
- tokens_t = int(ones(shape=[6])) * 99
-
- output_token_beam_list = output_token_beam_list.ScatterAssign(0, tokens_t)
- output_token_beam_list = output_token_beam_list.ExpandDims(dims=[2])
- output_prev_index_beam_list = output_prev_index_beam_list.ScatterAssign(
- 0, hypo_t)
- output_prev_index_beam_list = output_prev_index_beam_list.ExpandDims(dims=[2])
- output_score_beam_list = output_score_beam_list.ScatterAssign(0, scores_t)
- output_score_beam_list = output_score_beam_list.ExpandDims(dims=[2])
- output_attention_weights_beam_list = output_attention_weights_beam_list\
- .ScatterAssign(0, attention_t)
-
- length_32 = int(length)
-
- timestep = 0
- not_finished = True
- while not_finished:
- # TODO: once we have a metaprogramming facility we need to insert the
- # body of the post_eos_penalty here programmatically
-
- best_scores_per_hypo, best_tokens_per_hypo = log_probs.TopK(k=6)
-
- # Add the best score in each hypothesis to the cumulative score so far
- output_scores = best_scores_per_hypo + scores_t.Squeeze(dims=[0])
-
- # Flatten scores so we can find the best overall out of all hypotheses
- output_scores_flattened_slice, _ = output_scores.FlattenToVec()\
- .Slice(0, 6 if timestep == 0 else -1).Reshape(shape=[1, -1])
-
- # Find top K out of all
- scores_t, best_indices = output_scores_flattened_slice.TopK(k=6)
-
- # Integer floor divide on indices finds the association back to original
- # hypotheses. Use this to reorder states
- hypo_t_int64 = best_indices / 6LL
-
- # Reorder attentions
- attention_t, _ = attentions.Gather(hypo_t_int64)\
- .Reshape(shape=[1, 6, -1])
- tokens_t_int64 = best_tokens_per_hypo.FlattenToVec()\
- .Gather(best_indices).Cast(to=2)
-
- timestep += 1
- not_finished = timestep < length_32
-
- output_token_beam_list = output_token_beam_list\
- .ScatterAssign(timestep, tokens_t)
- output_prev_index_beam_list = output_prev_index_beam_list\
- .ScatterAssign(timestep, hypo_t)
- output_score_beam_list = output_score_beam_list\
- .ScatterAssign(timestep, scores_t)
- output_attention_weights_beam_list = output_attention_weights_beam_list\
- .ScatterAssign(timestep, attention_t)
diff --git a/caffe2/contrib/script/examples/example_post_eos_penalty.c2s b/caffe2/contrib/script/examples/example_post_eos_penalty.c2s
deleted file mode 100644
index 9988913..0000000
--- a/caffe2/contrib/script/examples/example_post_eos_penalty.c2s
+++ /dev/null
@@ -1,13 +0,0 @@
-[["tokens_t", [1, 6], "int32"], ["hypo_t", [1, 6], "int32"], ["log_probs", [6, 1, 44463], "float32"], ["on_initial_step", [1], "bool_"]]
-post_eos_penalty
-["log_probs"]
-
-def post_eos_penalty(tokens_t, hypo_t, log_probs, on_initial_step) \
- -> (log_probs):
- eos_token = 1
- finished_penalty = 0f if on_initial_step else 0.5f
- predecessor_tokens = tokens_t.FlattenToVec().Gather(hypo_t.FlattenToVec())
- predecessor_is_eos = float(predecessor_tokens == eos_token)
- log_probs = log_probs.Add(
- predecessor_is_eos * finished_penalty, broadcast=1, axis=0
- )
diff --git a/caffe2/contrib/script/examples/run_examples.py b/caffe2/contrib/script/examples/run_examples.py
deleted file mode 100644
index 26f2db0..0000000
--- a/caffe2/contrib/script/examples/run_examples.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-from caffe2.python import core, workspace
-import glob
-import json
-import numpy as np
-
-example_files = glob.glob('example_*.c2s')
-
-for ex in example_files:
- print('Running example file', ex)
- with open(ex, 'r') as f:
- inits = json.loads(f.readline())
- net_name = f.readline().strip()
- outputs = json.loads(f.readline())
-
- CU = core.C.CompilationUnit()
- CU.define(f.read())
-
- # Initialize workspace with required inputs
- for name, shape, dt in inits:
- workspace.FeedBlob(name, np.random.rand(*shape).astype(np.dtype(dt)))
-
- net = CU.create_net(net_name)
- net.run()
-
- print('Success! Interesting outputs:')
- for output in outputs:
- print(output, workspace.FetchBlob(output))
diff --git a/caffe2/contrib/script/lexer.cc b/caffe2/contrib/script/lexer.cc
deleted file mode 100644
index 9dafea9..0000000
--- a/caffe2/contrib/script/lexer.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "caffe2/contrib/script/lexer.h"
-#include "caffe2/core/common.h"
-
-namespace caffe2 {
-namespace script {
-
-std::string kindToString(int kind) {
- if (kind < 256)
- return std::string(1, kind);
- switch (kind) {
-#define DEFINE_CASE(tok, str, _) \
- case tok: \
- return str;
- TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
-#undef DEFINE_CASE
- default:
- throw std::runtime_error("unknown kind: " + c10::to_string(kind));
- }
-}
-
-SharedParserData& sharedParserData() {
- static SharedParserData data; // safely handles multi-threaded init
- return data;
-}
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/lexer.h b/caffe2/contrib/script/lexer.h
deleted file mode 100644
index b298809..0000000
--- a/caffe2/contrib/script/lexer.h
+++ /dev/null
@@ -1,527 +0,0 @@
-#pragma once
-#include <assert.h>
-#include <algorithm>
-#include <iostream>
-#include <memory>
-#include <sstream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "caffe2/core/common.h"
-
-namespace caffe2 {
-namespace script {
-
-// single character tokens are just the character itself '+'
-// multi-character tokens need an entry here
-// if the third entry is not the empty string, it is used
-// in the lexer to match this token.
-
-// These kinds are also used in Tree.h as the kind of the AST node.
-// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the
-// lexer.
-
-#define TC_FORALL_TOKEN_KINDS(_) \
- _(TK_EOF, "eof", "") \
- _(TK_WHITESPACE, "whitespace", "") \
- _(TK_NUMBER, "number", "") \
- _(TK_NEWLINE, "newline", "") \
- _(TK_INDENT, "indent", "") \
- _(TK_DEDENT, "dedent", "") \
- _(TK_WHERE, "where", "where") \
- _(TK_FLOAT, "float", "float") \
- _(TK_DOUBLE, "double", "double") \
- _(TK_LONG, "long", "long") \
- _(TK_INT, "int", "int") \
- _(TK_DEF, "def", "def") \
- _(TK_ARROW, "arrow", "->") \
- _(TK_EQUIVALENT, "equivalent", "<=>") \
- _(TK_IDENT, "ident", "") \
- _(TK_STRING, "string", "") \
- _(TK_CONST, "const", "") \
- _(TK_LIST, "list", "") \
- _(TK_OPTION, "option", "") \
- _(TK_APPLY, "apply", "") \
- _(TK_COMPREHENSION, "comprehension", "") \
- _(TK_TENSOR_TYPE, "tensor_type", "") \
- _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
- _(TK_PARAM, "param", "") \
- _(TK_INFERRED, "inferred", "") \
- _(TK_BOOL, "bool", "") \
- _(TK_ACCESS, "access", "") \
- _(TK_ASSIGN, "assign", "") \
- _(TK_ATTRIBUTE, "attribute", "") \
- _(TK_IF, "if", "if") \
- _(TK_ELSE, "else", "else") \
- _(TK_ELIF, "elif", "elif") \
- _(TK_WHILE, "while", "while") \
- _(TK_NE, "ne", "!=") \
- _(TK_EQ, "eq", "==") \
- _(TK_LE, "le", "<=") \
- _(TK_GE, "ge", ">=") \
- _(TK_IF_EXPR, "if", "") \
- _(TK_TRUE, "True", "True") \
- _(TK_FALSE, "False", "False") \
- _(TK_AND, "and", "and") \
- _(TK_OR, "or", "or") \
- _(TK_NOT, "not", "not") \
- _(TK_CAST, "cast", "") \
- _(TK_PLUS_EQ, "+=", "+=") \
- _(TK_MINUS_EQ, "-=", "-=") \
- _(TK_TIMES_EQ, "*=", "*=") \
- _(TK_DIV_EQ, "/=", "/=") \
- _(TK_GLOBAL, "global", "global") \
- _(TK_BUILT_IN, "built-in", "") \
- _(TK_SLICE, "slice", "") \
- _(TK_GATHER, "gather", "")
-static const char* valid_single_char_tokens = "+-*/()[]:,={}><.";
-
-enum TokenKind {
- // we use characters to represent themselves so skip all valid characters
- // before
- // assigning enum values to multi-char tokens.
- TK_DUMMY_START = 256,
-#define DEFINE_TOKEN(tok, _, _2) tok,
- TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
-#undef DEFINE_TOKEN
-};
-
-std::string kindToString(int kind);
-
-// nested hash tables that indicate char-by-char what is a valid token.
-struct TokenTrie;
-using TokenTrieRef = std::unique_ptr<TokenTrie>;
-struct TokenTrie {
- TokenTrie() : kind(0) {}
- void insert(const char* str, int tok) {
- if (*str == '\0') {
- assert(kind == 0);
- kind = tok;
- return;
- }
- auto& entry = children[*str];
- if (entry == nullptr) {
- entry.reset(new TokenTrie());
- }
- entry->insert(str + 1, tok);
- }
- int kind; // 0 == invalid token
- std::unordered_map<char, TokenTrieRef> children;
-};
-
-// stuff that is shared against all TC lexers/parsers and is initialized only
-// once.
-struct SharedParserData {
- SharedParserData() : head(new TokenTrie()) {
- // listed in increasing order of precedence
- std::vector<std::vector<int>> binary_ops = {
- {TK_IF},
- {TK_AND, TK_OR},
- {}, // reserve a level for unary not
- {'<', '>', TK_EQ, TK_LE, TK_GE, TK_NE},
- {'+', '-'},
- {'*', '/'},
- };
- std::vector<std::vector<int>> unary_ops = {
- {'-'},
- };
-
- std::stringstream ss;
- for (const char* c = valid_single_char_tokens; *c; c++) {
- const char str[] = {*c, '\0'};
- head->insert(str, *c);
- }
-
-#define ADD_CASE(tok, _, tokstring) \
- if (*tokstring != '\0') { \
- head->insert(tokstring, tok); \
- }
- TC_FORALL_TOKEN_KINDS(ADD_CASE)
-#undef ADD_CASE
-
- // precedence starts at 1 so that there is always a 0 precedence
- // less than any other precedence
- int prec = 1;
- for (auto& group : binary_ops) {
- for (auto& element : group) {
- binary_prec[element] = prec;
- }
- prec++;
- }
- // unary ops
- for (auto& group : unary_ops) {
- for (auto& element : group) {
- unary_prec[element] = prec;
- }
- prec++;
- }
- // add unary not separately because it slots into the precedence of
- // binary operators
- unary_prec[TK_NOT] = binary_prec[TK_AND] + 1;
- }
- // 1. skip whitespace
- // 2. handle comment or newline
- //
- bool isNumber(const std::string& str, size_t start, size_t* len) {
- char first = str[start];
- // strtod allows numbers to start with + or - or nan or inf
- // http://en.cppreference.com/w/cpp/string/byte/strtof
- // but we want only the number part, otherwise 1+3 will turn into two
- // adjacent numbers in the lexer
- if (first == '-' || first == '+' || isalpha(first))
- return false;
- const char* startptr = str.c_str() + start;
- char* endptr;
- std::strtod(startptr, &endptr);
- *len = endptr - startptr;
- return *len > 0;
- }
- bool isblank(int n) {
- return isspace(n) && n != '\n';
- }
- // find the longest match of str.substring(pos) against a token, return true
- // if successful
- // filling in kind, start,and len
- bool match(
- const std::string& str,
- size_t pos,
- bool continuation, // are we inside a scope where newlines don't count
- // (e.g. inside parens)
- bool whitespace_token, // should we treat whitespace as a token
- int* kind,
- size_t* start,
- size_t* len) {
- *start = pos;
- // skip whitespace
- while (pos < str.size() && isblank(str[pos]))
- pos++;
-
- // special handling
- if (pos < str.size()) {
- if (str[pos] == '#') {
- // skip comments
- while (pos < str.size() && str[pos] != '\n')
- pos++;
- // tail call, handle whitespace and more comments
- return match(
- str, pos, continuation, whitespace_token, kind, start, len);
- }
- if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
- !whitespace_token) {
- return match(str, pos + 2, continuation, false, kind, start, len);
- }
- if (str[pos] == '\n') {
- return match(
- str, pos + 1, continuation, !continuation, kind, start, len);
- }
- }
- if (pos == str.size()) {
- *kind = TK_EOF;
- *start = pos;
- *len = 0;
- return true;
- }
- // invariant: the next token is not whitespace or newline
- if (whitespace_token) {
- *kind = TK_WHITESPACE;
- *len = pos - *start;
- return true;
- }
- *start = pos;
- // check for a valid number
- if (isNumber(str, pos, len)) {
- *kind = TK_NUMBER;
- return true;
- }
- // check for either an ident or a token
- // ident tracks whether what we have scanned so far could be an identifier
- // matched indicates if we have found any match.
- bool matched = false;
- bool ident = true;
- TokenTrie* cur = head.get();
- for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
- ident = ident && validIdent(i, str[pos + i]);
- if (ident) {
- matched = true;
- *len = i + 1;
- *kind = TK_IDENT;
- }
- // check for token second, so that e.g. 'max' matches the token TK_MAX
- // rather the
- // identifier 'max'
- if (cur) {
- auto it = cur->children.find(str[pos + i]);
- cur = (it == cur->children.end()) ? nullptr : it->second.get();
- if (cur && cur->kind != 0) {
- matched = true;
- *len = i + 1;
- *kind = cur->kind;
- }
- }
- }
- return matched;
- }
- bool isUnary(int kind, int* prec) {
- auto it = unary_prec.find(kind);
- if (it != unary_prec.end()) {
- *prec = it->second;
- return true;
- }
- return false;
- }
- bool isBinary(int kind, int* prec) {
- auto it = binary_prec.find(kind);
- if (it != binary_prec.end()) {
- *prec = it->second;
- return true;
- }
- return false;
- }
- bool isRightAssociative(int kind) {
- switch (kind) {
- case '?':
- return true;
- default:
- return false;
- }
- }
-
- private:
- bool validIdent(size_t i, char n) {
- return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
- }
- TokenTrieRef head;
- std::unordered_map<int, int>
- unary_prec; // map from token to its unary precedence
- std::unordered_map<int, int>
- binary_prec; // map from token to its binary precedence
-};
-
-SharedParserData& sharedParserData();
-
-// a range of a shared string 'file_' with functions to help debug by highlight
-// that
-// range.
-struct SourceRange {
- SourceRange(
- const std::shared_ptr<std::string>& file_,
- size_t start_,
- size_t end_)
- : file_(file_), start_(start_), end_(end_) {}
- const std::string text() const {
- return file().substr(start(), end() - start());
- }
- size_t size() const {
- return end() - start();
- }
- void highlight(std::ostream& out) const {
- const std::string& str = file();
- size_t begin = start();
- size_t end = start();
- while (begin > 0 && str[begin - 1] != '\n')
- --begin;
- while (end < str.size() && str[end] != '\n')
- ++end;
- out << str.substr(0, end) << "\n";
- out << std::string(start() - begin, ' ');
- size_t len = std::min(size(), end - start());
- out << std::string(len, '~')
- << (len < size() ? "... <--- HERE" : " <--- HERE");
- out << str.substr(end);
- if (str.size() > 0 && str.back() != '\n')
- out << "\n";
- }
- const std::string& file() const {
- return *file_;
- }
- const std::shared_ptr<std::string>& file_ptr() const {
- return file_;
- }
- size_t start() const {
- return start_;
- }
- size_t end() const {
- return end_;
- }
-
- private:
- std::shared_ptr<std::string> file_;
- size_t start_;
- size_t end_;
-};
-
-struct Token {
- int kind;
- SourceRange range;
- Token(int kind, const SourceRange& range) : kind(kind), range(range) {}
- double doubleValue() {
- assert(TK_NUMBER == kind);
- size_t idx;
- double r = ::c10::stod(text(), &idx);
- assert(idx == range.size());
- return r;
- }
- std::string text() {
- return range.text();
- }
- std::string kindString() const {
- return kindToString(kind);
- }
-};
-
-struct Lookahead {
- Lookahead(const Token& t) : t(t) {}
- Token t;
- bool valid = false;
- size_t repeat = 0;
-};
-
-struct Lexer {
- std::shared_ptr<std::string> file;
- explicit Lexer(const std::string& str)
- : file(std::make_shared<std::string>(str)),
- pos(0),
- cur_(TK_EOF, SourceRange(file, 0, 0)),
- lookahead_(cur_),
- repeat(0),
- nesting(0),
- shared(sharedParserData()) {
- auto first_indent = lexRaw(true);
- indent_stack.push_back(first_indent.range.size());
- next();
- }
- Token next() {
- Token r = cur_;
- if (repeat > 0) {
- repeat--;
- } else if (lookahead_.valid) {
- lookahead_.valid = false;
- repeat = lookahead_.repeat;
- cur_ = lookahead_.t;
- } else {
- std::tie(cur_, repeat) = lex();
- }
- return r;
- }
- bool nextIf(int kind) {
- if (cur_.kind != kind)
- return false;
- next();
- return true;
- }
-
- [[noreturn]] void reportError(const std::string& what) {
- reportError(what, cur_);
- }
- [[noreturn]] void reportError(const std::string& what, const Token& t) {
- std::stringstream ss;
- ss << what << ":\n";
- t.range.highlight(ss);
- throw std::runtime_error(ss.str());
- }
- [[noreturn]] void expected(const std::string& what, const Token& t) {
- std::stringstream ss;
- ss << "expected " << what << " but found '" << t.kindString()
- << "' here:\n";
- t.range.highlight(ss);
- throw std::runtime_error(ss.str());
- }
- [[noreturn]] void expected(const std::string& what) {
- expected(what, cur_);
- }
- Token expect(int kind) {
- if (cur_.kind != kind) {
- expected(kindToString(kind));
- }
- return next();
- }
- Token& lookahead() {
- if (!lookahead_.valid) {
- lookahead_.valid = true;
- std::tie(lookahead_.t, lookahead_.repeat) = lex();
- }
- return lookahead_.t;
- }
- Token& cur() {
- return cur_;
- }
-
- private:
- // token, number of times to repeat it
- std::pair<Token, int> lex() {
- auto r = lexRaw();
- int repeat = 0;
- switch (r.kind) {
- case '(':
- case '[':
- case '{':
- nesting++;
- break;
- case ')':
- case ']':
- case '}':
- nesting--;
- break;
- case TK_WHITESPACE: {
- size_t depth = r.range.size();
- if (depth > indent_stack.back()) {
- indent_stack.push_back(depth);
- r.kind = TK_INDENT;
- } else if (depth == indent_stack.back()) {
- r.kind = TK_NEWLINE;
- } else {
- while (indent_stack.back() != depth) {
- indent_stack.pop_back();
- repeat++;
- if (indent_stack.size() == 0) {
- reportError("invalid ident level", r);
- }
- }
- repeat--; // first repeat is this return
- r.kind = TK_DEDENT;
- }
- } break;
- case TK_EOF:
- if (indent_stack.size() > 1) {
- r.kind = TK_DEDENT;
- indent_stack.pop_back();
- }
- break;
- default:
- break;
- }
- return std::make_pair(r, repeat);
- }
- Token lexRaw(bool whitespace_token = false) {
- int kind;
- size_t start;
- size_t length;
- assert(file);
- if (!shared.match(
- *file,
- pos,
- nesting > 0,
- whitespace_token,
- &kind,
- &start,
- &length)) {
- expected(
- "a valid token",
- Token((*file)[start], SourceRange(file, start, start + 1)));
- }
- auto t = Token(kind, SourceRange(file, start, start + length));
- pos = start + length;
- return t;
- }
- size_t pos;
- Token cur_;
- Lookahead lookahead_;
- size_t repeat; // how many times to repeat the current token until we continue
-
- size_t nesting; // depth of ( [ { nesting...
- std::vector<int> indent_stack; // stack of identation level of blocks
- SharedParserData& shared;
-};
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/parser.h b/caffe2/contrib/script/parser.h
deleted file mode 100644
index 4b68b8d..0000000
--- a/caffe2/contrib/script/parser.h
+++ /dev/null
@@ -1,418 +0,0 @@
-#pragma once
-#include "lexer.h"
-#include "tree.h"
-#include "tree_views.h"
-
-namespace caffe2 {
-namespace script {
-
-struct Parser {
- explicit Parser(const std::string& str)
- : L(str), shared(sharedParserData()) {}
-
- TreeRef parseIdent() {
- auto t = L.expect(TK_IDENT);
- // whenever we parse something that has a TreeView type we always
- // use its create method so that the accessors and the constructor
- // of the Compound tree are in the same place.
- return Ident::create(t.range, t.text());
- }
- TreeRef createApply(TreeRef ident, TreeList& inputs) {
- TreeList attributes;
- auto range = L.cur().range;
- parseOperatorArguments(inputs, attributes);
- return Apply::create(
- range,
- ident,
- List(range, std::move(inputs)),
- List(range, std::move(attributes)));
- }
- // things like a 1.0 or a(4) that are not unary/binary expressions
- // and have higher precedence than all of them
- TreeRef parseBaseExp() {
- TreeRef prefix;
- switch (L.cur().kind) {
- case TK_NUMBER:
- case TK_TRUE:
- case TK_FALSE: {
- prefix = parseConst();
- } break;
- case '(': {
- L.next();
- prefix = parseExp();
- L.expect(')');
- } break;
- case TK_FLOAT:
- case TK_INT:
- case TK_LONG: {
- auto r = L.cur().range;
- auto type = c(L.next().kind, r, {});
- L.expect('(');
- auto exp = parseExp();
- L.expect(')');
- prefix = Cast::create(r, type, exp);
- } break;
- default: {
- prefix = parseIdent();
- if (L.cur().kind == '(') {
- TreeList inputs;
- prefix = createApply(prefix, inputs);
- }
- } break;
- }
- while (true) {
- if (L.nextIf('.')) {
- const auto name = parseIdent();
- if (L.cur().kind == '(') {
- TreeList inputs = {prefix};
- prefix = createApply(name, inputs);
- } else {
- prefix = Select::create(name->range(), prefix, name);
- }
- } else if (L.cur().kind == '[') {
- prefix = parseSliceOrGather(prefix);
- } else {
- break;
- }
- }
- return prefix;
- }
- TreeRef parseOptionalReduction() {
- auto r = L.cur().range;
- switch (L.cur().kind) {
- case TK_PLUS_EQ:
- case TK_MINUS_EQ:
- case TK_TIMES_EQ:
- case TK_DIV_EQ: {
- int modifier = L.next().text()[0];
- return c(modifier, r, {});
- } break;
- default: {
- L.expect('=');
- return c('=', r, {}); // no reduction
- } break;
- }
- }
- TreeRef
- parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
- auto cond = parseExp();
- L.expect(TK_ELSE);
- auto false_branch = parseExp(binary_prec);
- return c(TK_IF_EXPR, range, {cond, true_branch, false_branch});
- }
- // parse the longest expression whose binary operators have
- // precedence strictly greater than 'precedence'
- // precedence == 0 will parse _all_ expressions
- // this is the core loop of 'top-down precedence parsing'
- TreeRef parseExp(int precedence = 0) {
- TreeRef prefix = nullptr;
- int unary_prec;
- if (shared.isUnary(L.cur().kind, &unary_prec)) {
- auto kind = L.cur().kind;
- auto pos = L.cur().range;
- L.next();
- prefix = c(kind, pos, {parseExp(unary_prec)});
- } else {
- prefix = parseBaseExp();
- }
- int binary_prec;
- while (shared.isBinary(L.cur().kind, &binary_prec)) {
- if (binary_prec <= precedence) // not allowed to parse something which is
- // not greater than 'precedenc'
- break;
-
- int kind = L.cur().kind;
- auto pos = L.cur().range;
- L.next();
- if (shared.isRightAssociative(kind))
- binary_prec--;
-
- // special case for trinary operator
- if (kind == TK_IF) {
- prefix = parseTrinary(prefix, pos, binary_prec);
- continue;
- }
-
- prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
- }
- return prefix;
- }
- TreeRef
- parseList(int begin, int sep, int end, std::function<TreeRef(int)> parse) {
- auto r = L.cur().range;
- L.expect(begin);
- TreeList elements;
- if (L.cur().kind != end) {
- int i = 0;
- do {
- elements.push_back(parse(i++));
- } while (L.nextIf(sep));
- }
- L.expect(end);
- return c(TK_LIST, r, std::move(elements));
- }
- TreeRef parseNonEmptyList(int sep, std::function<TreeRef(int)> parse) {
- TreeList elements;
- int i = 0;
- do {
- elements.push_back(parse(i++));
- } while (L.nextIf(sep));
- return c(TK_LIST, elements[0]->range(), std::move(elements));
- }
- TreeRef parseExpList() {
- return parseList('(', ',', ')', [&](int i) { return parseExp(); });
- }
- TreeRef parseConst() {
- // 'b' - boolean
- // 'LL' 64-bit integer
- // 'f' single-precision float
- // 'i' 32-bit integer
- // 'f' is default if '.' appears in the number
- auto range = L.cur().range;
- if (L.nextIf(TK_TRUE)) {
- return c(TK_CONST, range, {d(1), s("b")});
- } else if (L.nextIf(TK_FALSE)) {
- return c(TK_CONST, range, {d(0), s("b")});
- }
- float mult = 1.0f;
- while (L.nextIf('-')) {
- mult *= -1.0f;
- }
- auto t = L.expect(TK_NUMBER);
- std::string type_ident =
- (t.text().find('.') == std::string::npos) ? "i" : "f";
- if (L.cur().kind == TK_IDENT) {
- Token type_ident_tok = L.expect(TK_IDENT);
- type_ident = type_ident_tok.text();
- if (type_ident != "LL" && type_ident != "f") {
- throw ErrorReport(type_ident_tok)
- << "expected 'f' or 'LL' "
- << "as numeric type identifier but found '" << type_ident << "'";
- }
- }
- return c(TK_CONST, t.range, {d(mult * t.doubleValue()), s(type_ident)});
- }
- TreeRef parseAttributeValue() {
- int kind = L.cur().kind;
- switch (kind) {
- case '[':
- return parseList('[', ',', ']', [&](int i) { return parseConst(); });
- default:
- return parseConst();
- }
- }
- void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
- L.expect('(');
- if (L.cur().kind != ')') {
- do {
- if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
- auto ident = parseIdent();
- L.expect('=');
- auto v = parseAttributeValue();
- attributes.push_back(Attribute::create(ident->range(), ident, v));
- } else {
- inputs.push_back(parseExp());
- }
- } while (L.nextIf(','));
- }
- L.expect(')');
- }
-
- // OK: [a] (gather), [a:], [:a], [a:b], [:] (slice)
- // Not OK: []
- TreeRef parseSliceOrGather(TreeRef value) {
- const auto range = L.cur().range;
- L.expect('[');
-
- // `first` will either be the gather indices, or the start of the slice.
- TreeRef first, second;
-
- // Here we can either have a colon (which starts a slice), or an expression.
- // If an expression, we don't know yet if it will be a slice or a gather.
- if (L.cur().kind != ':') {
- first = parseExp();
- if (L.nextIf(']')) {
- return Gather::create(range, value, first);
- } else {
- first = c(TK_OPTION, range, {first});
- }
- } else {
- first = c(TK_OPTION, range, {});
- }
- L.expect(':');
- // Now we *may* have an expression.
- if (L.cur().kind != ']') {
- second = c(TK_OPTION, range, {parseExp()});
- } else {
- second = c(TK_OPTION, range, {});
- }
- L.expect(']');
-
- return Slice::create(range, value, first, second);
- }
- TreeRef parseIdentList() {
- return parseList('(', ',', ')', [&](int i) { return parseIdent(); });
- }
- TreeRef parseParam() {
- auto typ = parseType();
- if (L.cur().kind != TK_IDENT && typ->trees()[0]->kind() == TK_IDENT) {
- // oops, it wasn't a type but just a param without any type specified
- return Param::create(
- typ->range(), typ->trees()[0], c(TK_INFERRED, typ->range(), {}));
- }
- auto ident = parseIdent();
- return Param::create(typ->range(), ident, typ);
- }
- // TODO: these functions should be unnecessary, but we currently do not
- // emit a TK_NEWLINE before a series of TK_DEDENT tokens
- // so if we see a TK_DEDENT then we know a newline must have happened and
- // ignore it. The real fix is to patch the lexer so TK_NEWLINE does get
- // emited before a TK_INDENT
- void expectEndOfLine() {
- if (L.cur().kind != TK_DEDENT)
- L.expect(TK_NEWLINE);
- }
- bool isEndOfLine() {
- return L.cur().kind == TK_NEWLINE || L.cur().kind == TK_DEDENT;
- }
-
- // 'first' has already been parsed since expressions can exist
- // alone on a line:
- // first[,other,lhs] = rhs
- TreeRef parseAssign(TreeRef first) {
- TreeRef list = parseOneOrMoreExp(first);
- auto red = parseOptionalReduction();
- auto rhs = parseExp();
- expectEndOfLine();
- return Assign::create(list->range(), list, red, rhs);
- }
- TreeRef parseStmt() {
- switch (L.cur().kind) {
- case TK_IF:
- return parseIf();
- case TK_WHILE:
- return parseWhile();
- case TK_GLOBAL: {
- auto range = L.next().range;
- std::vector<TreeRef> idents;
- do {
- idents.push_back(parseIdent());
- } while (L.nextIf(','));
- expectEndOfLine();
- return c(TK_GLOBAL, range, std::move(idents));
- }
- default: {
- auto r = parseExp();
- if (!isEndOfLine()) {
- return parseAssign(r);
- } else {
- expectEndOfLine();
- return r;
- }
- }
- }
- }
- TreeRef parseScalarType() {
- switch (L.cur().kind) {
- case TK_INT:
- case TK_FLOAT:
- case TK_LONG:
- case TK_DOUBLE: {
- auto t = L.next();
- return c(t.kind, t.range, {});
- }
- default:
- return parseIdent();
- }
- }
- TreeRef parseOptionalIdentList() {
- TreeRef list = nullptr;
- if (L.cur().kind == '(') {
- list = parseIdentList();
- } else {
- list = c(TK_LIST, L.cur().range, {});
- }
- return list;
- }
- TreeRef parseType() {
- auto st = parseScalarType();
- auto list = parseOptionalIdentList();
- return TensorType::create(st->range(), st, list);
- }
- // 'first' has already been parsed, add the rest
- // if they exist
- // first[, the, rest]
- TreeRef parseOneOrMoreExp(TreeRef first) {
- TreeList list{first};
- while (L.nextIf(',')) {
- list.push_back(parseExp());
- }
- return List(list.back()->range(), std::move(list));
- }
- TreeRef parseIf() {
- auto r = L.cur().range;
- L.expect(TK_IF);
- auto cond = parseExp();
- L.expect(':');
- auto true_branch = parseStatements();
- auto false_branch = List(L.cur().range, {});
- if (L.nextIf(TK_ELSE)) {
- L.expect(':');
- false_branch = parseStatements();
- }
- return If::create(r, cond, true_branch, false_branch);
- }
- TreeRef parseWhile() {
- auto r = L.cur().range;
- L.expect(TK_WHILE);
- auto cond = parseExp();
- L.expect(':');
- auto body = parseStatements();
- return While::create(r, cond, body);
- }
- TreeRef parseStatements() {
- auto r = L.cur().range;
- L.expect(TK_INDENT);
- TreeList stmts;
- while (true) {
- stmts.push_back(parseStmt());
- if (L.nextIf(TK_DEDENT))
- break;
- }
- return c(TK_LIST, r, std::move(stmts));
- }
- TreeRef parseFunction() {
- L.expect(TK_DEF);
- auto name = parseIdent();
- auto paramlist =
- parseList('(', ',', ')', [&](int i) { return parseParam(); });
- L.expect(TK_ARROW);
- auto retlist =
- parseList('(', ',', ')', [&](int i) { return parseParam(); });
- L.expect(':');
- auto stmts_list = parseStatements();
- return Def::create(name->range(), name, paramlist, retlist, stmts_list);
- }
- Lexer& lexer() {
- return L;
- }
-
- private:
- // short helpers to create nodes
- TreeRef d(double v) {
- return Number::create(v);
- }
- TreeRef s(const std::string& s) {
- return String::create(s);
- }
- TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
- return Compound::create(kind, range, std::move(trees));
- }
- TreeRef List(const SourceRange& range, TreeList&& trees) {
- return c(TK_LIST, range, std::move(trees));
- }
- Lexer L;
- SharedParserData& shared;
-};
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/tree.h b/caffe2/contrib/script/tree.h
deleted file mode 100644
index c508308..0000000
--- a/caffe2/contrib/script/tree.h
+++ /dev/null
@@ -1,233 +0,0 @@
-#pragma once
-
-#include <memory>
-#include <vector>
-
-#include "caffe2/contrib/script/lexer.h"
-
-namespace caffe2 {
-namespace script {
-
-// Tree's are used to represent all forms of TC IR, pre- and post- typechecking.
-// Rather than have a full class hierarchy for all TC statements,
-// Trees are a slight variation of Lisp S-expressions.
-// for instance the expression a*b+1 is represented as:
-// (+ (* (ident a) (ident b)) (const 1))
-// Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
-// define stringValue() and doubleValue().
-// Everything else is a Compound object, which has a 'kind' that is a token from
-// Lexer.h's TokenKind enum, and contains a list of subtrees.
-// Like TokenKind single-character operators like '+' are representing using the
-// character itself, so add.kind() == '+'.
-// Compound objects are also always associated with a SourceRange for
-// reporting error message.
-
-// Memory management of trees is done using shared_ptr.
-
-struct Tree;
-using TreeRef = std::shared_ptr<Tree>;
-using TreeList = std::vector<TreeRef>;
-
-static const TreeList empty_trees = {};
-
-struct Tree : std::enable_shared_from_this<Tree> {
- Tree(int kind_) : kind_(kind_) {}
- int kind() const {
- return kind_;
- }
- virtual bool isAtom() const {
- return true;
- }
- virtual const SourceRange& range() const {
- throw std::runtime_error("is an Atom");
- }
- virtual double doubleValue() const {
- throw std::runtime_error("not a TK_NUMBER");
- }
- virtual const std::string& stringValue() const {
- throw std::runtime_error("not a TK_STRING");
- }
- virtual bool boolValue() const {
- throw std::runtime_error("not a TK_BOOL");
- }
- virtual const TreeList& trees() const {
- return empty_trees;
- }
- const TreeRef& tree(size_t i) const {
- return trees().at(i);
- }
- virtual TreeRef map(std::function<TreeRef(TreeRef)> /*fn*/) {
- return shared_from_this();
- }
- template <typename... Args>
- void match(int k, Args&... args) {
- matchD(k, "unknown", 0, args...);
- }
- template <typename... Args>
- void matchD(int k, const char* filename, int lineno, Args&... args) {
- if (kind() != k) {
- std::stringstream ss;
- ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
- << "' but found '" << kind() << "'\n";
- range().highlight(ss);
- throw std::runtime_error(ss.str());
- }
- std::initializer_list<TreeRef*> vars = {&args...};
- if (vars.size() > trees().size()) {
- std::stringstream ss;
- ss << filename << ":" << lineno << ": trying to match " << vars.size()
- << " variables against " << trees().size() << " values in list.\n";
- range().highlight(ss);
- throw std::runtime_error(ss.str());
- }
- size_t i = 0;
- for (TreeRef* v : vars) {
- *v = trees()[i++];
- }
- }
- virtual ~Tree() {}
-
- private:
- int kind_;
-};
-
-struct String : public Tree {
- String(const std::string& value_) : Tree(TK_STRING), value_(value_) {}
- virtual const std::string& stringValue() const override {
- return value_;
- }
- template <typename... Args>
- static TreeRef create(Args&&... args) {
- return std::make_shared<String>(std::forward<Args>(args)...);
- }
-
- private:
- std::string value_;
-};
-struct Number : public Tree {
- Number(double value_) : Tree(TK_NUMBER), value_(value_) {}
- virtual double doubleValue() const override {
- return value_;
- }
- template <typename... Args>
- static TreeRef create(Args&&... args) {
- return std::make_shared<Number>(std::forward<Args>(args)...);
- }
-
- private:
- double value_;
-};
-struct Bool : public Tree {
- Bool(bool value_) : Tree(TK_BOOL), value_(value_) {}
- virtual double doubleValue() const override {
- return value_;
- }
- template <typename... Args>
- static TreeRef create(Args&&... args) {
- return std::make_shared<Bool>(std::forward<Args>(args)...);
- }
-
- private:
- bool value_;
-};
-
-static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
- for (auto t : others) {
- if (t->isAtom())
- continue;
- size_t s = std::min(c.start(), t->range().start());
- size_t e = std::max(c.end(), t->range().end());
- c = SourceRange(c.file_ptr(), s, e);
- }
- return c;
-}
-
-struct Compound : public Tree {
- Compound(int kind, const SourceRange& range_) : Tree(kind), range_(range_) {}
- Compound(int kind, const SourceRange& range_, TreeList&& trees_)
- : Tree(kind),
- range_(mergeRanges(range_, trees_)),
- trees_(std::move(trees_)) {}
- virtual const TreeList& trees() const override {
- return trees_;
- }
- static TreeRef
- create(int kind, const SourceRange& range_, TreeList&& trees_) {
- return std::make_shared<Compound>(kind, range_, std::move(trees_));
- }
- virtual bool isAtom() const override {
- return false;
- }
- virtual TreeRef map(std::function<TreeRef(TreeRef)> fn) override {
- TreeList trees_;
- for (auto& t : trees()) {
- trees_.push_back(fn(t));
- }
- return Compound::create(kind(), range(), std::move(trees_));
- }
- const SourceRange& range() const override {
- return range_;
- }
-
- private:
- SourceRange range_;
- TreeList trees_;
-};
-
-// tree pretty printer
-struct pretty_tree {
- pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
- const TreeRef& tree;
- size_t col;
- std::unordered_map<TreeRef, std::string> flat_strings;
- const std::string& get_flat(const TreeRef& t) {
- auto it = flat_strings.find(t);
- if (it != flat_strings.end())
- return it->second;
-
- std::stringstream out;
- switch (t->kind()) {
- case TK_NUMBER:
- out << t->doubleValue();
- break;
- case TK_STRING:
- out << t->stringValue();
- break;
- default:
- out << "(" << kindToString(t->kind());
- for (auto e : t->trees()) {
- out << " " << get_flat(e);
- }
- out << ")";
- break;
- }
- auto it_ = flat_strings.emplace(t, out.str());
- return it_.first->second;
- }
- void print(std::ostream& out, const TreeRef& t, int indent) {
- const std::string& s = get_flat(t);
- if (indent + s.size() < col || t->isAtom()) {
- out << s;
- return;
- }
- std::string k = kindToString(t->kind());
- out << "(" << k;
- for (auto e : t->trees()) {
- out << "\n" << std::string(indent + 2, ' ');
- print(out, e, indent + 2);
- }
- out << ")";
- }
-};
-
-static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
- t_.print(out, t_.tree, 0);
- return out << std::endl;
-}
-
-static inline std::ostream& operator<<(std::ostream& out, TreeRef t) {
- return out << pretty_tree(t);
-}
-
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/contrib/script/tree_views.h b/caffe2/contrib/script/tree_views.h
deleted file mode 100644
index 2089333..0000000
--- a/caffe2/contrib/script/tree_views.h
+++ /dev/null
@@ -1,442 +0,0 @@
-#pragma once
-#include "error_report.h"
-#include "tree.h"
-
-namespace caffe2 {
-namespace script {
-
-// TreeView provides a statically-typed way to access the members of a TreeRef
-// instead of using TK_MATCH
-
-struct TreeView {
- explicit TreeView(const TreeRef& tree_) : tree_(tree_) {}
- TreeRef tree() const {
- return tree_;
- }
- const SourceRange& range() const {
- return tree_->range();
- }
- operator TreeRef() const {
- return tree_;
- }
-
- protected:
- TreeRef tree_;
-};
-
-template <typename T>
-struct ListViewIterator {
- ListViewIterator(TreeList::const_iterator it) : it(it) {}
- bool operator!=(const ListViewIterator& rhs) const {
- return it != rhs.it;
- }
- T operator*() const {
- return T(*it);
- }
- void operator++() {
- ++it;
- }
- void operator--() {
- --it;
- }
-
- private:
- TreeList::const_iterator it;
-};
-
-template <typename T>
-struct ListView : public TreeView {
- ListView(const TreeRef& tree) : TreeView(tree) {
- tree->match(TK_LIST);
- }
- typedef ListViewIterator<T> iterator;
- typedef ListViewIterator<T> const_iterator;
- iterator begin() const {
- return iterator(tree_->trees().begin());
- }
- iterator end() const {
- return iterator(tree_->trees().end());
- }
- T operator[](size_t i) const {
- return T(tree_->trees().at(i));
- }
- TreeRef map(std::function<TreeRef(const T&)> fn) {
- return tree_->map([&](TreeRef v) { return fn(T(v)); });
- }
- size_t size() const {
- return tree_->trees().size();
- }
-};
-
-template <typename T>
-struct OptionView : public TreeView {
- explicit OptionView(const TreeRef& tree) : TreeView(tree) {
- C2S_ASSERT(tree, tree->kind() == TK_OPTION);
- }
- bool present() const {
- return tree_->trees().size() > 0;
- }
- T get() const {
- C2S_ASSERT(tree_, present());
- return T(tree_->trees()[0]);
- }
- TreeRef map(std::function<TreeRef(const T&)> fn) {
- return tree_->map([&](TreeRef v) { return fn(T(v)); });
- }
-};
-
-struct Ident : public TreeView {
- // each subclass of TreeView provides:
- // 1. a constructor that takes a TreeRef, and matches it to the right type.
- explicit Ident(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_IDENT, name_);
- }
- // 2. accessors that get underlying information out of the object
- // in this case, we return the name of the identifier, and handle the
- // converstion to a string in the method
- const std::string& name() const {
- return name_->stringValue();
- }
-
- // 3. a static method 'create' that creates the underlying TreeRef object
- // for every TreeRef kind that has a TreeView, the parser always uses
- // (e.g.) Ident::create rather than Compound::Create, this means that
- // changes to the structure of Ident are always made right here rather
- // than both in the parser and in this code
- static TreeRef create(const SourceRange& range, const std::string& name) {
- return Compound::create(TK_IDENT, range, {String::create(name)});
- }
-
- private:
- TreeRef name_;
-};
-
-struct Attribute : public TreeView {
- explicit Attribute(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_ATTRIBUTE, name_, value_);
- }
- Ident name() const {
- return Ident(name_);
- }
- TreeRef value() const {
- return value_;
- }
- static TreeRef create(const SourceRange& range, TreeRef name, TreeRef value) {
- return Compound::create(TK_ATTRIBUTE, range, {name, value});
- }
-
- private:
- TreeRef name_;
- TreeRef value_;
-};
-
-struct Apply : public TreeView {
- explicit Apply(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_APPLY, name_, inputs_, attributes_);
- }
-
- Ident name() const {
- return Ident(name_);
- }
- ListView<TreeRef> inputs() const {
- return ListView<TreeRef>(inputs_);
- }
- ListView<Attribute> attributes() const {
- return ListView<Attribute>(attributes_);
- }
-
- static TreeRef create(
- const SourceRange& range,
- TreeRef name,
- TreeRef inputs,
- TreeRef attributes) {
- return Compound::create(TK_APPLY, range, {name, inputs, attributes});
- }
-
- private:
- TreeRef name_;
- TreeRef inputs_;
- TreeRef attributes_;
-};
-
-struct Slice : public TreeView {
- explicit Slice(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_SLICE, value_, start_, end_);
- }
-
- TreeRef value() const {
- return value_;
- }
-
- OptionView<TreeRef> start() const {
- return OptionView<TreeRef>(start_);
- }
-
- OptionView<TreeRef> end() const {
- return OptionView<TreeRef>(end_);
- }
-
- TreeRef startOr(int alternative) const {
- const auto startOption = start();
- return startOption.present() ? startOption.get() : createInt(alternative);
- }
-
- TreeRef endOr(int alternative) const {
- const auto endOption = end();
- return endOption.present() ? endOption.get() : createInt(alternative);
- }
-
- static TreeRef
- create(const SourceRange& range, TreeRef value, TreeRef start, TreeRef end) {
- return Compound::create(TK_SLICE, range, {value, start, end});
- }
-
- private:
- TreeRef createInt(int value) const {
- return Compound::create(
- TK_CONST, range(), {Number::create(value), String::create("i")});
- }
-
- TreeRef value_;
- TreeRef start_;
- TreeRef end_;
-};
-
-struct Gather : public TreeView {
- explicit Gather(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_GATHER, value_, indices_);
- }
-
- TreeRef value() const {
- return value_;
- }
-
- TreeRef indices() const {
- return indices_;
- }
-
- static TreeRef
- create(const SourceRange& range, TreeRef value, TreeRef indices) {
- return Compound::create(TK_GATHER, range, {value, indices});
- }
-
- private:
- TreeRef value_;
- TreeRef indices_;
-};
-
-struct Cast : public TreeView {
- explicit Cast(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_CAST, type_, input_);
- }
-
- int type() const {
- return type_->kind();
- }
- TreeRef input() const {
- return input_;
- }
-
- static TreeRef create(const SourceRange& range, TreeRef type, TreeRef input) {
- return Compound::create(TK_CAST, range, {type, input});
- }
-
- private:
- TreeRef type_;
- TreeRef input_;
-};
-
-struct TensorType : public TreeView {
- explicit TensorType(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_TENSOR_TYPE, scalar_type_, dims_);
- }
- static TreeRef
- create(const SourceRange& range, TreeRef scalar_type_, TreeRef dims_) {
- return Compound::create(TK_TENSOR_TYPE, range, {scalar_type_, dims_});
- }
- int scalarType() const {
- if (scalar_type_->kind() == TK_IDENT)
- throw ErrorReport(tree_)
- << " TensorType has a symbolic ident " << Ident(scalar_type_).name()
- << " rather than a concrete type";
- return scalar_type_->kind();
- }
- ListView<Ident> dims() const {
- return ListView<Ident>(dims_);
- }
-
- private:
- TreeRef scalar_type_;
- TreeRef dims_;
-};
-
-struct Param : public TreeView {
- explicit Param(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_PARAM, ident_, type_);
- }
- static TreeRef create(const SourceRange& range, TreeRef ident, TreeRef type) {
- return Compound::create(TK_PARAM, range, {ident, type});
- }
- // when the type of a field is statically know the accessors return
- // the wrapped type. for instance here we know ident_ is an identifier
- // so the accessor returns an Ident
- // this means that clients can do p.ident().name() to get the name of the
- // parameter.
- Ident ident() const {
- return Ident(ident_);
- }
- // may be TensorType or TK_INFERRED
- TreeRef type() const {
- return type_;
- }
- bool typeIsInferred() const {
- return type_->kind() == TK_INFERRED;
- }
- // helper for when you know the type is not inferred.
- TensorType tensorType() const {
- return TensorType(type_);
- }
-
- private:
- TreeRef ident_;
- TreeRef type_;
-};
-
-struct Assign : public TreeView {
- explicit Assign(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_ASSIGN, lhs_, reduction_, rhs_);
- }
- static TreeRef create(
- const SourceRange& range,
- TreeRef lhs,
- TreeRef reduction,
- TreeRef rhs) {
- return Compound::create(TK_ASSIGN, range, {lhs, reduction, rhs});
- }
- // when the type of a field is statically know the accessors return
- // the wrapped type. for instance here we know ident_ is an identifier
- // so the accessor returns an Ident
- // this means that clients can do p.ident().name() to get the name of the
- // parameter.
- ListView<TreeRef> lhs() const {
- return ListView<TreeRef>(lhs_);
- }
- int reduction() const {
- return reduction_->kind();
- }
- TreeRef rhs() const {
- return rhs_;
- }
-
- private:
- TreeRef lhs_;
- TreeRef reduction_;
- TreeRef rhs_;
-};
-
-struct Def : public TreeView {
- explicit Def(const TreeRef& tree) : TreeView(tree) {
- tree->match(TK_DEF, name_, paramlist, retlist, stmts_list);
- }
- Ident name() const {
- return Ident(name_);
- }
- // ListView helps turn TK_LISTs into vectors of TreeViews
- // so that we can, e.g., return lists of parameters
- ListView<Param> params() const {
- return ListView<Param>(paramlist);
- }
- ListView<Param> returns() const {
- return ListView<Param>(retlist);
- }
- ListView<TreeRef> statements() const {
- return ListView<TreeRef>(stmts_list);
- }
- static TreeRef create(
- const SourceRange& range,
- TreeRef name,
- TreeRef paramlist,
- TreeRef retlist,
- TreeRef stmts_list) {
- return Compound::create(
- TK_DEF, range, {name, paramlist, retlist, stmts_list});
- }
-
- private:
- TreeRef name_;
- TreeRef paramlist;
- TreeRef retlist;
- TreeRef stmts_list;
-};
-
-struct Select : public TreeView {
- explicit Select(const TreeRef& tree) : TreeView(tree) {
- tree_->match('.', value_, selector_);
- }
- TreeRef value() const {
- return value_;
- }
- Ident selector() const {
- return Ident(selector_);
- }
- static TreeRef
- create(const SourceRange& range, TreeRef value, TreeRef selector) {
- return Compound::create('.', range, {value, selector});
- }
-
- private:
- TreeRef value_;
- TreeRef selector_;
-};
-
-struct If : public TreeView {
- explicit If(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_IF, cond_, true_branch_, false_branch_);
- }
- const TreeRef& cond() const {
- return cond_;
- }
- ListView<TreeRef> trueBranch() const {
- return ListView<TreeRef>(true_branch_);
- }
- ListView<TreeRef> falseBranch() const {
- return ListView<TreeRef>(false_branch_);
- }
-
- static TreeRef create(
- const SourceRange& range,
- TreeRef cond_,
- TreeRef true_branch_,
- TreeRef false_branch_) {
- return Compound::create(TK_IF, range, {cond_, true_branch_, false_branch_});
- }
-
- private:
- TreeRef cond_;
- TreeRef true_branch_;
- TreeRef false_branch_;
-};
-
-struct While : public TreeView {
- explicit While(const TreeRef& tree) : TreeView(tree) {
- tree_->match(TK_WHILE, cond_, body_);
- }
- const TreeRef& cond() const {
- return cond_;
- }
- ListView<TreeRef> body() const {
- return ListView<TreeRef>(body_);
- }
-
- static TreeRef
- create(const SourceRange& range, TreeRef cond_, TreeRef body_) {
- return Compound::create(TK_WHILE, range, {cond_, body_});
- }
-
- private:
- TreeRef cond_;
- TreeRef body_;
-};
-
-} // namespace script
-} // namespace caffe2
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 709879a..a4a1509 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -6,7 +6,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-#include "caffe2/contrib/script/compiler.h"
#include "caffe2/core/asan.h"
#include "caffe2/core/blob_stats.h"
#include "caffe2/core/db.h"
@@ -938,29 +937,6 @@
}
return pyout;
});
-
- py::class_<script::CompilationUnit>(m, "CompilationUnit")
- .def(py::init<>())
- .def("define", &script::CompilationUnit::define)
- .def("get_proto", &script::CompilationUnit::getProto)
- .def(
- "create_net",
- [](script::CompilationUnit* self, const std::string& name) {
- auto net = self->createNet(gWorkspace, name);
- CAFFE_ENFORCE(net);
- return net;
- })
- .def(
- "extern",
- [](script::CompilationUnit* self,
- const std::string& name,
- py::object py_proto) {
- py::bytes bytes = py_proto.attr("SerializeToString")();
- std::unique_ptr<caffe2::NetDef> proto(new NetDef());
- CAFFE_ENFORCE(ParseProtoFromLargeString(
- bytes.cast<std::string>(), proto.get()));
- self->defineExtern(name, std::move(proto));
- });
}
void addGlobalMethods(py::module& m) {