Load original SourceRanges on import (#22180)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22180
ghimport-source-id: efa46dcb845c099f0a746f523901ab2c2cd3b004
Test Plan: Imported from OSS
Differential Revision: D15981425
Pulled By: jamesr66a
fbshipit-source-id: bef682bd13c1a5be95bdb97e025690c6f2d523d3
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 2a337c2..1c918e7 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -440,6 +440,7 @@
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
diff --git a/test/jit_utils.py b/test/jit_utils.py
index 49616058..fd1aefb 100644
--- a/test/jit_utils.py
+++ b/test/jit_utils.py
@@ -24,6 +24,7 @@
import io
import math
import os
+import pickle
import tempfile
import textwrap
@@ -120,6 +121,8 @@
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
main_module = archive.open('archive/code/archive.py')
main_module_code = "".join([line.decode() for line in main_module])
+ main_module_debug_file = archive.open('archive/debug/archive.pkl')
+ main_module_debug = pickle.load(main_module_debug_file)
except RuntimeError as e:
if not self._isHookExceptionOk(e):
raise
@@ -138,8 +141,11 @@
archive2 = zipfile.ZipFile(saved_module_buffer_2)
main_module_2 = archive2.open('archive/code/archive.py')
main_module_2_code = "".join([line.decode() for line in main_module_2])
+ main_module_2_debug_file = archive.open('archive/debug/archive.pkl')
+ main_module_2_debug = pickle.load(main_module_2_debug_file)
self.assertMultiLineEqual(main_module_code, main_module_2_code)
+ self.assertEqual(main_module_debug, main_module_2_debug)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
if isinstance(m, torch._C.Function):
diff --git a/test/test_jit.py b/test/test_jit.py
index c619abe..46242b1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3323,6 +3323,82 @@
fc.run(scripted.graph)
fc.run(str(scripted.graph))
+ def test_serialized_source_ranges(self):
+
+ class FooTest(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x, w):
+ return torch.mm(x, w.t())
+
+ ft = FooTest()
+ loaded = self.getExportImportCopy(ft)
+ _, lineno = inspect.getsourcelines(FooTest)
+
+ with self.assertRaisesRegex(RuntimeError, 'test_jit.py:{}'.format(lineno + 3)):
+ loaded(torch.rand(3, 4), torch.rand(30, 40))
+
+ def test_serialized_source_ranges2(self):
+
+ class FooTest2(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self):
+ raise RuntimeError('foo')
+
+ _, lineno = inspect.getsourcelines(FooTest2)
+
+ with self.assertRaisesRegex(torch._C.JITException, 'test_jit.py:{}'.format(lineno + 3)):
+ ft = FooTest2()
+ loaded = self.getExportImportCopy(ft)
+ loaded()
+
+ def test_serialized_source_ranges_dont_jitter(self):
+ class FooTest3(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, lim):
+ first = 1
+ second = 1
+ i = 1
+ somenum = 5
+ dontmutateme = 3
+ third = 0
+ while bool(i < lim):
+ third = first + second
+ first = second
+ second = third
+ j = 0
+ while j < 10:
+ somenum = somenum * 2
+ j = j + 1
+ i = i + j
+ i = i + dontmutateme
+
+ st = second + third
+ fs = first + second
+ return third, st, fs
+
+ ft3 = FooTest3()
+
+ def debug_records_from_mod(mod):
+ buffer = io.BytesIO()
+ torch.jit.save(ft3, buffer)
+ buffer.seek(0)
+ archive = zipfile.ZipFile(buffer)
+ debug_file = archive.open('archive/debug/archive.pkl')
+ return pickle.load(debug_file), buffer
+
+ records1, buffer = debug_records_from_mod(ft3)
+
+ buffer.seek(0)
+ loaded = torch.jit.load(buffer)
+ records2, buffer = debug_records_from_mod(loaded)
+
+ buffer.seek(0)
+ loaded2 = torch.jit.load(buffer)
+ records3, _ = debug_records_from_mod(loaded2)
+
+ self.assertEqual(records1, records2)
+ self.assertEqual(records2, records3)
+
def test_tensor_shape(self):
x = torch.empty(34, 56, 78)
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 6c7ea10..50e6ddd 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -123,6 +123,7 @@
"torch/csrc/jit/script/class_type.cpp",
"torch/csrc/jit/script/parser.cpp",
"torch/csrc/jit/script/jit_exception.cpp",
+ "torch/csrc/jit/source_range_serialization.cpp",
"torch/csrc/jit/testing/file_check.cpp",
"torch/csrc/jit/import_source.cpp",
"torch/csrc/jit/hooks_for_testing.cpp",
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 63fef41..c8b10eb 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -11,7 +11,7 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickler.h>
-#include <torch/csrc/jit/source_range_serializer.h>
+#include <torch/csrc/jit/source_range_serialization.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
@@ -931,21 +931,14 @@
// Write out debug records
torch::RecordRef* debug_record =
module_def->mutable_torchscript_debug_arena();
- Pickler p;
- SourceRangeSerializer srs;
- p.start();
- p.startTuple();
- for (const auto& range : source_ranges) {
- std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
- srs.serialize(range.range)};
- p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
- }
- p.endTuple();
- p.finish();
+
+ SourceRangePickler source_range_pickler;
+ source_range_pickler.pickle(source_ranges);
+ const auto& range_data = source_range_pickler.get_data();
std::stringstream debug_filename;
debug_filename << "debug/" << module_name.str() << ".pkl";
writer_.writeRecord(
- debug_filename.str(), p.stack().data(), p.stack().size());
+ debug_filename.str(), range_data.data(), range_data.size());
debug_record->set_key(debug_filename.str());
}
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 980772f..9f3db62 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -9,6 +9,8 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
+#include <torch/csrc/jit/source_range_serialization.h>
+#include <torch/csrc/jit/source_range_serialization_impl.h>
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
@@ -328,6 +330,20 @@
module.register_attribute(
attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
}
+
+ // If present, load in the table of source ranges from the original
+ // generating code.
+ std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
+ if (module_def.has_torchscript_debug_arena()) {
+ at::DataPtr data;
+ size_t size;
+ std::tie(data, size) =
+ reader_.getRecord(module_def.torchscript_debug_arena().key());
+
+ gen_ranges =
+ std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
+ }
+
if (module_def.has_torchscript_arena()) {
at::DataPtr data;
size_t size;
@@ -337,7 +353,8 @@
auto src = std::make_shared<Source>(
std::string(static_cast<const char*>(data.get()), size),
module_def.torchscript_arena().key(),
- 1);
+ 1,
+ std::move(gen_ranges));
std::function<void(const std::string&)> import_callback =
[this](const std::string& qualifier) { importCallback(qualifier); };
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 000e32e..8ccb7e6 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -171,10 +171,13 @@
SourceRangeStack source_range_stack_ = {SourceRange("")};
struct WithSourceRange {
- explicit WithSourceRange(SourceRangeStack* stack, SourceRange sr)
- : stack(stack) {
+ explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
TORCH_INTERNAL_ASSERT(stack);
- stack->push_back(std::move(sr));
+ if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
+ stack->push_back(std::move(gen_source.value()));
+ } else {
+ stack->push_back(std::move(n->sourceRange()));
+ }
}
~WithSourceRange() {
@@ -190,6 +193,13 @@
TaggedStringStream(TaggedStringStream&& rhs) = default;
TaggedStringStream& operator<<(const std::string& s) {
+ // This prevents having redundant entries at the same offset,
+ // which can happen for example in printValueList when begin
+ // and end are the empty string.
+ if (s.size() == 0) {
+ return *this;
+ }
+
if (!ranges_.size() || ranges_.back().range != srs_->back()) {
ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
}
@@ -233,6 +243,21 @@
return ranges_;
}
+ // Write out this TaggedStringStream's text and source ranges to
+ // os and source_ranges_out, respectively. stream_pos gives
+ // the byte offset into the current stream, so we can accurately
+ // record source ranges as byte offsets.
+ void print(
+ std::ostream& os,
+ SourceRangeRecords* source_ranges_out,
+ int64_t stream_pos) {
+ os << str();
+ for (const auto& x : ranges()) {
+ source_ranges_out->push_back(x);
+ source_ranges_out->back().bytes += stream_pos;
+ }
+ }
+
private:
std::ostringstream oss_;
std::vector<TaggedRange> ranges_;
@@ -756,7 +781,7 @@
}
void printNode(Node* node, bool print_const) {
- WithSourceRange guard(&source_range_stack_, node->sourceRange());
+ WithSourceRange guard(&source_range_stack_, node);
// Check for class dependencies. If this node inputs or outputs a class
// type, we need to add it to our table of dependencies.
for (const auto input : node->inputs()) {
@@ -1149,8 +1174,7 @@
Graph& graph = *func.graph();
used_names_.clear(); // each graph can reuse local names
- WithSourceRange guard(
- &source_range_stack_, graph.param_node()->sourceRange());
+ WithSourceRange guard(&source_range_stack_, graph.param_node());
indent();
body_ << "def " << func.name() << "(";
@@ -1250,8 +1274,9 @@
}
void print(std::ostream& out, SourceRangeRecords& source_ranges_out) {
- out << getImports() << body_.str();
- source_ranges_out = body_.ranges();
+ out << getImports();
+ int64_t source_offset = out.tellp();
+ body_.print(out, &source_ranges_out, source_offset);
}
};
diff --git a/torch/csrc/jit/passes/python_print.h b/torch/csrc/jit/passes/python_print.h
index 9ee0d9f..8bdeaec 100644
--- a/torch/csrc/jit/passes/python_print.h
+++ b/torch/csrc/jit/passes/python_print.h
@@ -12,16 +12,6 @@
struct Module;
} // namespace script
-// A pair of (byte offset, SourceRange) describing a specific segment
-// of the output stream
-struct TaggedRange {
- TaggedRange(size_t bytes, SourceRange range)
- : bytes(bytes), range(std::move(range)) {}
- size_t bytes;
- SourceRange range;
-};
-using SourceRangeRecords = std::vector<TaggedRange>;
-
TORCH_API void PythonPrint(
std::ostream& out,
SourceRangeRecords& source_ranges_out,
diff --git a/torch/csrc/jit/source_range.cpp b/torch/csrc/jit/source_range.cpp
index 578e6dd..8a5e573 100644
--- a/torch/csrc/jit/source_range.cpp
+++ b/torch/csrc/jit/source_range.cpp
@@ -1,8 +1,17 @@
#include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/jit/source_range_serialization.h>
namespace torch {
namespace jit {
+c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
+ const SourceRange& range) {
+ if (!gen_ranges_) {
+ return c10::nullopt;
+ }
+ return gen_ranges_->findSourceRangeThatGenerated(range);
+}
+
// a range of a shared string 'file_' with
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
const std::string& str = source_->text();
@@ -56,6 +65,13 @@
out << str.substr(end_line, end_highlight - end_line);
if (!str.empty() && str.back() != '\n')
out << "\n";
+ // Retrieve original SourceRange, if present.
+ if (source_) {
+ if (auto orig_source_range = findSourceRangeThatGenerated()) {
+ out << "Compiled from code ";
+ orig_source_range->highlight(out);
+ }
+ }
}
} // namespace jit
diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h
index 5362457..4176e79 100644
--- a/torch/csrc/jit/source_range.h
+++ b/torch/csrc/jit/source_range.h
@@ -8,6 +8,9 @@
namespace torch {
namespace jit {
+struct SourceRangeUnpickler;
+struct SourceRange;
+
// Source represents a code segment. It keeps track of:
// - text : the text of the code segment
// - filename (optional) : if present, represents the name of the file from
@@ -15,18 +18,25 @@
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct Source {
- explicit Source(std::string text)
- : text_(std::move(text)), filename_(c10::nullopt) {
+ explicit Source(
+ std::string text,
+ std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
+ : text_(std::move(text)),
+ filename_(c10::nullopt),
+ starting_line_no_(0),
+ gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
Source(
std::string text,
c10::optional<std::string> filename,
- size_t starting_line_no)
+ size_t starting_line_no,
+ std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_(std::move(text)),
filename_(std::move(filename)),
- starting_line_no_(starting_line_no) {
+ starting_line_no_(starting_line_no),
+ gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
@@ -67,6 +77,9 @@
return starting_line_no_;
}
+ c10::optional<SourceRange> findSourceRangeThatGenerated(
+ const SourceRange& range);
+
private:
void calc_line_start_offsets() {
size_t pos = 0;
@@ -82,6 +95,8 @@
// Starting offsets for lines into the source. e.g. line 0 starts at
// line_starting_offsets_[0], etc.
std::vector<size_t> line_starting_offsets_;
+
+ std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
};
// A SourceRange is a view into a Source, that points to a subset of the source,
@@ -139,6 +154,13 @@
bool operator!=(const SourceRange& rhs) const {
return !(*this == rhs);
}
+
+ c10::optional<SourceRange> findSourceRangeThatGenerated() const {
+ if (!source_) {
+ return c10::nullopt;
+ }
+ return source_->findSourceRangeThatGenerated(*this);
+ }
private:
std::shared_ptr<Source> source_;
@@ -151,5 +173,15 @@
return out;
}
+// A pair of (byte offset, SourceRange) describing a specific segment
+// of the output stream
+struct TaggedRange {
+ TaggedRange(size_t bytes, SourceRange range)
+ : bytes(bytes), range(std::move(range)) {}
+ size_t bytes;
+ SourceRange range;
+};
+using SourceRangeRecords = std::vector<TaggedRange>;
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/source_range_serialization.cpp b/torch/csrc/jit/source_range_serialization.cpp
new file mode 100644
index 0000000..f42cbce
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization.cpp
@@ -0,0 +1,148 @@
+#include <torch/csrc/jit/source_range_serialization.h>
+#include <torch/csrc/jit/source_range_serialization_impl.h>
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/pickler.h>
+
+namespace torch {
+namespace jit {
+
+class SourceRangeSerializer {
+ public:
+ // Serialize SourceRange as Tuple[SourceType, int, int]
+ // where SourceType = Tuple[str, Optional[str], int, List[int]],
+ // the serialized form of Source
+ c10::IValue serialize(const SourceRange& sr);
+
+ private:
+ // Serialize Source as Tuple[str, Optional[str], int, List[int]]
+ // This caches serialized sources, since many SourceRanges can
+ // refer to the same one.
+ c10::IValue serialize_source(const std::shared_ptr<Source>& s);
+
+ std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
+};
+
+class SourceRangeDeserializer {
+ public:
+ SourceRange deserialize(const c10::IValue& iv) {
+ auto tup_elems = iv.toTuple()->elements();
+ TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
+ std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
+ int64_t start_ = tup_elems[1].toInt();
+ int64_t end_ = tup_elems[2].toInt();
+ return SourceRange(source_, start_, end_);
+ }
+
+ private:
+ std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) {
+ auto tup = iv.toTuple();
+ if (cached_sources.count(tup)) {
+ return cached_sources.at(tup);
+ }
+
+ auto tup_elems = tup->elements();
+ TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
+ std::string text_ = tup_elems[0].toString()->string();
+ c10::optional<std::string> filename_ =
+ tup_elems[1].toOptional<std::string>();
+ int64_t starting_line_no_ = tup_elems[2].toInt();
+
+ auto source = std::make_shared<Source>(
+ std::move(text_), std::move(filename_), starting_line_no_);
+ cached_sources[tup] = source;
+ return source;
+ }
+
+ std::unordered_map<
+ c10::intrusive_ptr<c10::ivalue::Tuple>,
+ std::shared_ptr<Source>>
+ cached_sources;
+};
+
+c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
+ std::vector<c10::IValue> elements = {
+ serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
+ return c10::ivalue::Tuple::create(std::move(elements));
+}
+
+c10::IValue SourceRangeSerializer::serialize_source(
+ const std::shared_ptr<Source>& s) {
+ if (serialized_sources.count(s)) {
+ return serialized_sources.at(s);
+ }
+ std::vector<c10::IValue> elements{
+ s->text(), s->filename(), (int64_t)s->starting_line_no()};
+ auto serialized = c10::ivalue::Tuple::create(std::move(elements));
+ serialized_sources[s] = serialized;
+ return serialized;
+}
+
+SourceRangePickler::SourceRangePickler()
+ : p(new Pickler()), srs(new SourceRangeSerializer()) {}
+
+void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
+ p->start();
+ p->startTuple();
+ for (const auto& range : ranges) {
+ std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
+ srs->serialize(range.range)};
+ p->addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
+ }
+ p->endTuple();
+ p->finish();
+}
+
+const std::vector<char>& SourceRangePickler::get_data() {
+ return p->stack();
+}
+
+ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
+ at::DataPtr&& data,
+ size_t size)
+ : data(std::move(data)),
+ size(size),
+ deserializer(new SourceRangeDeserializer()),
+ unpickled_records(nullptr) {}
+
+void ConcreteSourceRangeUnpickler::unpickle() {
+ if (unpickled_records) {
+ return;
+ }
+
+ Unpickler up(data.get(), size, nullptr);
+ auto ivalues = up.parse_ivalue_list();
+
+ unpickled_records = std::make_shared<SourceRangeRecords>();
+ for (auto& val : ivalues) {
+ auto tup_elems = val.toTuple()->elements();
+ int64_t offset = tup_elems[0].toInt();
+ auto source_range = deserializer->deserialize(tup_elems[1]);
+ unpickled_records->emplace_back(offset, std::move(source_range));
+ }
+}
+
+c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
+ findSourceRangeThatGenerated(const SourceRange& range) {
+ unpickle();
+
+ auto query = TaggedRange(range.start(), SourceRange{""});
+ auto entry = std::upper_bound(
+ unpickled_records->begin(),
+ unpickled_records->end(),
+ query,
+ [](const TaggedRange& a, const TaggedRange& b) -> bool {
+ return a.bytes < b.bytes;
+ });
+
+ // NB: must decrement iterator since upper_bound finds the element
+ // *greater than* the query.
+ if (entry != unpickled_records->begin()) {
+ return (entry - 1)->range;
+ }
+
+ return c10::nullopt;
+}
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serialization.h b/torch/csrc/jit/source_range_serialization.h
new file mode 100644
index 0000000..349ef16
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <c10/core/Allocator.h>
+#include <torch/csrc/jit/source_range.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace c10 {
+struct IValue;
+}
+
+namespace torch {
+namespace jit {
+
+class Pickler;
+class SourceRangeSerializer;
+class SourceRangeDeserializer;
+
+class SourceRangePickler {
+ public:
+ SourceRangePickler();
+
+ void pickle(const SourceRangeRecords& ranges);
+
+ const std::vector<char>& get_data();
+
+ private:
+ std::shared_ptr<Pickler> p;
+ std::shared_ptr<SourceRangeSerializer> srs;
+};
+
+class SourceRangeUnpickler {
+ public:
+ virtual c10::optional<SourceRange> findSourceRangeThatGenerated(
+ const SourceRange& range) = 0;
+
+ virtual ~SourceRangeUnpickler() {}
+};
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serialization_impl.h b/torch/csrc/jit/source_range_serialization_impl.h
new file mode 100644
index 0000000..1e1bb10
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization_impl.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/csrc/jit/pickler.h>
+#include <torch/csrc/jit/source_range_serialization.h>
+
+namespace torch {
+namespace jit {
+
+// Do this clownyness with virtual functions because of the split
+// between ATen core and torch
+
+class ConcreteSourceRangeUnpickler : public SourceRangeUnpickler {
+ public:
+ ConcreteSourceRangeUnpickler(at::DataPtr&& data, size_t size);
+
+ c10::optional<SourceRange> findSourceRangeThatGenerated(
+ const SourceRange& range) override;
+
+ private:
+ at::DataPtr data;
+ size_t size;
+
+ void unpickle();
+
+ std::shared_ptr<SourceRangeDeserializer> deserializer;
+ std::shared_ptr<SourceRangeRecords> unpickled_records;
+};
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serializer.h b/torch/csrc/jit/source_range_serializer.h
deleted file mode 100644
index 8c10e18..0000000
--- a/torch/csrc/jit/source_range_serializer.h
+++ /dev/null
@@ -1,39 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-#include <torch/csrc/jit/source_range.h>
-
-namespace torch {
-namespace jit {
-
-class SourceRangeSerializer {
- public:
- // Serialize SourceRange as Tuple[SourceType, int, int]
- // where SourceType = Tuple[str, Optional[str], int, List[int]],
- // the serialized form of Source
- c10::IValue serialize(const SourceRange& sr) {
- std::vector<c10::IValue> elements = {
- serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
- return c10::ivalue::Tuple::create(std::move(elements));
- }
-
- private:
- // Serialize Source as Tuple[str, Optional[str], int, List[int]]
- // This caches serialized sources, since many SourceRanges can
- // refer to the same one.
- c10::IValue serialize_source(const std::shared_ptr<Source>& s) {
- if (serialized_sources.count(s)) {
- return serialized_sources.at(s);
- }
- std::vector<c10::IValue> elements{
- s->text(), s->filename(), (int64_t)s->starting_line_no()};
- auto serialized = c10::ivalue::Tuple::create(std::move(elements));
- serialized_sources[s] = serialized;
- return serialized;
- }
-
- std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
-};
-
-} // namespace jit
-} // namespace torch
\ No newline at end of file