| #pragma once |
| |
| #include <ATen/Backtrace.h> |
| #include <ATen/core/functional.h> |
| #include <ATen/core/stack.h> |
| #include <c10/util/Exception.h> |
| #include <torch/csrc/WindowsTorchApiMacro.h> |
| #include <torch/csrc/autograd/function_hook.h> |
| #include <torch/csrc/autograd/variable.h> |
| #include <torch/csrc/jit/constants.h> |
| #include <torch/csrc/jit/ir.h> |
| #include <torch/csrc/jit/tracing_state.h> |
| #include <torch/csrc/utils/variadic.h> |
| |
| #include <cstdint> |
| #include <iostream> |
| #include <memory> |
| #include <mutex> |
| #include <unordered_map> |
| #include <vector> |
| |
| namespace torch { |
| namespace jit { |
| |
| namespace script { |
| struct Module; |
| } |
| |
| namespace tracer { |
| |
| using ::c10::ivalue::List; |
| using ::c10::ivalue::Shared; |
| |
| using ::c10::IValue; |
| using ::c10::ivalue::Future; |
| using ::c10::ivalue::Tuple; |
| |
| using ::c10::ivalue::BoolList; |
| using ::c10::ivalue::DoubleList; |
| using ::c10::ivalue::GenericList; |
| using ::c10::ivalue::IntList; |
| using ::c10::ivalue::TensorList; |
| |
| using ::c10::ivalue::ConstantString; |
| |
| using torch::autograd::Variable; |
| using variable_list = std::vector<Variable>; |
| |
// Record a source location on node 'n'.  The actual recording strategy is
// pluggable via setRecordSourceLocation below.
TORCH_API void recordSourceLocation(Node* n);
// Install the callback invoked by recordSourceLocation (presumably so a
// frontend, e.g. Python, can supply its own stack-walking logic — confirm
// against the definition in tracer.cpp).
TORCH_API void setRecordSourceLocation(void (*v)(Node*));

// Having finished adding a new 'node' to the graph IR 'setValueTrace'
// associates this node with an output variable, so that further operations
// involving this variable know which node in the IR to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value);

// Remove the value-trace association for 'var' (e.g. once the variable is
// no longer live — confirm caller expectations in tracer.cpp).
TORCH_API void delValueTrace(const Variable& var);

// Temporarily suspend tracing; the returned callable is expected to resume
// it when invoked (use like a scope guard around untraced regions).
TORCH_API std::function<void()> pauseTracing();

// Look up the IR Value currently associated with 'var' (the inverse of
// setValueTrace).
TORCH_API Value* getValueTrace(const IValue& var);

// Variant of getValueTrace that, judging by the name, also descends into
// nested containers — confirm exact semantics in tracer.cpp.
TORCH_API Value* getNestedValueTrace(const IValue& v);

// Find the IR Value in 'state''s graph that corresponds to the output
// variable 'var'.
TORCH_API Value* getOutputTrace(
    const std::shared_ptr<TracingState>& state,
    const Variable& var);

// Nested-container variant of getOutputTrace (see note on
// getNestedValueTrace above).
TORCH_API Value* getNestedOutputTrace(
    const std::shared_ptr<TracingState>& state,
    const IValue& iv);
| |
| struct TypedStack : public std::pair<Stack, TupleTypePtr> |
| { |
| using pair::pair; |
| |
| // NB: The inherited default constructor gives nullptr for |type|, |
| // so we provide a saner one. |
| TypedStack() |
| : pair({}, TupleType::create({})) |
| {} |
| |
| Stack& stack() { |
| return this->first; |
| } |
| TupleTypePtr& types() { |
| return this->second; |
| } |
| size_t size() { |
| auto s = stack().size(); |
| AT_ASSERT(s == types()->elements().size()); |
| return s; |
| } |
| }; |
| |
// Begin a trace: returns the freshly created TracingState together with the
// stack of traced inputs.  'self' is optionally the script module whose
// method is being traced (nullptr when tracing a free function) — see
// tracer.cpp for the detailed contract.
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs, const std::shared_ptr<script::Module>& self=nullptr);

// Finish the trace begun by enter(), registering 'outputs' as the graph
// outputs.
TORCH_API void exit(const Stack& outputs);

// Abort the current trace without producing a graph.
TORCH_API void abandon();

// NB: those serve both as an intermediate steps in addInputs below,
// as well as the overloads that terminate template recursion

// --- scalar / primitive inputs ------------------------------------------
TORCH_API void addInputs(Node* n, const char* name, int64_t value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    c10::optional<int64_t> value);
TORCH_API void addInputs(Node* n, const char* name, bool value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<bool>& value);
TORCH_API void addInputs(Node* n, const char* name, double value);
TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::Scalar>& value);

// --- tensor and list inputs ---------------------------------------------
TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
TORCH_API void addInputs(Node* n, const char* name, at::IntArrayRef value);
// 'allow_undefined' tolerates undefined tensors in the list — confirm the
// exact handling in tracer.cpp.
TORCH_API void addInputs(
    Node* n,
    const char* name,
    at::TensorList value,
    bool allow_undefined = false);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const ArrayRef<double>& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const std::vector<double>& value);
TORCH_API void addInputs(Node* n, const char* name, const std::string& value);

// --- tensor-option / device / dtype inputs ------------------------------
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const at::TensorOptions& value);
TORCH_API void addInputs(Node* n, const char* name, at::Device value);
TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::ScalarType>& value);
TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);

// Generic container fallbacks: declared here, defined below — they always
// error, since arbitrary element types cannot be encoded in the trace.
template<typename T>
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const std::vector<T>& value);

template<typename K, typename V>
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const std::unordered_map<K, V>& value);
| |
| template<typename T> |
| void addInputs( |
| Node* n, |
| const char* name, |
| const std::vector<T>& value) { |
| AT_ERROR("Tracing a list of arbitrary type is currently not supported!"); |
| } |
| template<typename K, typename V> |
| void addInputs( |
| Node* n, |
| const char* name, |
| const std::unordered_map<K, V>& value) { |
| AT_ERROR("Tracing a dict of arbitrary types is currently not supported!"); |
| } |
| |
| template <size_t N> |
| void addInputs(Node* n, const char* name, std::array<bool, N> value) { |
| throw std::runtime_error( |
| "Found an unsupported argument type in the JIT tracer. File a bug report."); |
| } |
| |
// Sanity check used around traced in-place ops; judging by the name it
// verifies that 'tensor' is not shared when an op traced as out-of-place
// would mutate it — confirm against the definition in tracer.cpp.
TORCH_API void ensureUniqueIfOutOfPlaced(
    const char* name,
    const at::Tensor& tensor);
| |
// Fallback addOutput for any type that is neither convertible to at::Tensor
// nor to at::TensorList (gated by the enable_if default template argument).
// Always errors: only tensor(-list) outputs have concrete overloads below.
template <
    typename T,
    typename = torch::enable_if_t<
        (!std::is_convertible<torch::decay_t<T>, at::TensorList>::value &&
         !std::is_convertible<torch::decay_t<T>, at::Tensor>::value)>>
void addOutput(Node* node, T&&) {
  AT_ERROR(
      "Found an unsupported argument type ",
      c10::demangle_type<T>(),
      " in the JIT tracer. File a bug report.");
}
// Register a tensor output of 'node' with the tracer.
TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
// Associate an already-created IR 'value' with the runtime tensor 'output'.
TORCH_API void setOutput(Value* value, const at::Tensor& output);
// Register a list of tensor outputs of 'node'.
TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);

// Produce a Variable holding var.size(dim); presumably the size query is
// itself recorded in the trace so the graph stays shape-generic — confirm
// in tracer.cpp.
TORCH_API autograd::Variable getSizeOf(
    const autograd::Variable& var,
    int64_t dim);
| |
| } // namespace tracer |
| } // namespace jit |
| } // namespace torch |