Load original SourceRanges on import (#22180)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22180
ghimport-source-id: efa46dcb845c099f0a746f523901ab2c2cd3b004

Test Plan: Imported from OSS

Differential Revision: D15981425

Pulled By: jamesr66a

fbshipit-source-id: bef682bd13c1a5be95bdb97e025690c6f2d523d3
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 2a337c2..1c918e7 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -440,6 +440,7 @@
     ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
diff --git a/test/jit_utils.py b/test/jit_utils.py
index 49616058..fd1aefb 100644
--- a/test/jit_utils.py
+++ b/test/jit_utils.py
@@ -24,6 +24,7 @@
 import io
 import math
 import os
+import pickle
 import tempfile
 import textwrap
 
@@ -120,6 +121,8 @@
                 self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
                 main_module = archive.open('archive/code/archive.py')
                 main_module_code = "".join([line.decode() for line in main_module])
+                main_module_debug_file = archive.open('archive/debug/archive.pkl')
+                main_module_debug = pickle.load(main_module_debug_file)
             except RuntimeError as e:
                 if not self._isHookExceptionOk(e):
                     raise
@@ -138,8 +141,11 @@
             archive2 = zipfile.ZipFile(saved_module_buffer_2)
             main_module_2 = archive2.open('archive/code/archive.py')
             main_module_2_code = "".join([line.decode() for line in main_module_2])
+            main_module_2_debug_file = archive2.open('archive/debug/archive.pkl')  # second export
+            main_module_2_debug = pickle.load(main_module_2_debug_file)
 
             self.assertMultiLineEqual(main_module_code, main_module_2_code)
+            self.assertEqual(main_module_debug, main_module_2_debug)
 
     def getExportImportCopy(self, m, also_test_file=True, map_location=None):
         if isinstance(m, torch._C.Function):
diff --git a/test/test_jit.py b/test/test_jit.py
index c619abe..46242b1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3323,6 +3323,82 @@
         fc.run(scripted.graph)
         fc.run(str(scripted.graph))
 
+    def test_serialized_source_ranges(self):
+        """Errors raised by a re-imported module must cite the original test_jit.py line."""
+        class FooTest(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, w):
+                return torch.mm(x, w.t())
+
+        reloaded = self.getExportImportCopy(FooTest())
+        _, class_lineno = inspect.getsourcelines(FooTest)
+        # forward's return statement sits three lines below the `class` line.
+        expected_location = 'test_jit.py:{}'.format(class_lineno + 3)
+        with self.assertRaisesRegex(RuntimeError, expected_location):
+            reloaded(torch.rand(3, 4), torch.rand(30, 40))
+
+    def test_serialized_source_ranges2(self):
+        """A `raise` in a re-imported module must report its original test_jit.py line."""
+        class FooTest2(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self):
+                raise RuntimeError('foo')
+
+        _, class_lineno = inspect.getsourcelines(FooTest2)
+        # The raise statement sits three lines below the `class` line.
+        expected_location = 'test_jit.py:{}'.format(class_lineno + 3)
+        with self.assertRaisesRegex(torch._C.JITException, expected_location):
+            module = FooTest2()
+            self.getExportImportCopy(module)()
+
+    def test_serialized_source_ranges_dont_jitter(self):
+        class FooTest3(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, lim):
+                first = 1
+                second = 1
+                i = 1
+                somenum = 5
+                dontmutateme = 3
+                third = 0
+                while bool(i < lim):
+                    third = first + second
+                    first = second
+                    second = third
+                    j = 0
+                    while j < 10:
+                        somenum = somenum * 2
+                        j = j + 1
+                    i = i + j
+                    i = i + dontmutateme
+
+                st = second + third
+                fs = first + second
+                return third, st, fs
+
+        ft3 = FooTest3()
+        # Save `mod` to an in-memory archive; return (debug records, buffer).
+        def debug_records_from_mod(mod):
+            buffer = io.BytesIO()
+            torch.jit.save(mod, buffer)  # not ft3: each round trip must save its own module
+            buffer.seek(0)
+            archive = zipfile.ZipFile(buffer)
+            debug_file = archive.open('archive/debug/archive.pkl')
+            return pickle.load(debug_file), buffer
+
+        records1, buffer = debug_records_from_mod(ft3)
+        # Round trip 1: reload what was just saved, then save the reloaded copy.
+        buffer.seek(0)
+        loaded = torch.jit.load(buffer)
+        records2, buffer = debug_records_from_mod(loaded)
+        # Round trip 2: same again, starting from the second archive.
+        buffer.seek(0)
+        loaded2 = torch.jit.load(buffer)
+        records3, _ = debug_records_from_mod(loaded2)
+
+        self.assertEqual(records1, records2)
+        self.assertEqual(records2, records3)
+
     def test_tensor_shape(self):
         x = torch.empty(34, 56, 78)
 
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 6c7ea10..50e6ddd 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -123,6 +123,7 @@
     "torch/csrc/jit/script/class_type.cpp",
     "torch/csrc/jit/script/parser.cpp",
     "torch/csrc/jit/script/jit_exception.cpp",
+    "torch/csrc/jit/source_range_serialization.cpp",
     "torch/csrc/jit/testing/file_check.cpp",
     "torch/csrc/jit/import_source.cpp",
     "torch/csrc/jit/hooks_for_testing.cpp",
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 63fef41..c8b10eb 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -11,7 +11,7 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/pickler.h>
-#include <torch/csrc/jit/source_range_serializer.h>
+#include <torch/csrc/jit/source_range_serialization.h>
 
 #include <caffe2/core/types.h>
 #include <caffe2/proto/caffe2_pb.h>
@@ -931,21 +931,14 @@
     // Write out debug records
     torch::RecordRef* debug_record =
         module_def->mutable_torchscript_debug_arena();
-    Pickler p;
-    SourceRangeSerializer srs;
-    p.start();
-    p.startTuple();
-    for (const auto& range : source_ranges) {
-      std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
-                                         srs.serialize(range.range)};
-      p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
-    }
-    p.endTuple();
-    p.finish();
+
+    SourceRangePickler source_range_pickler;
+    source_range_pickler.pickle(source_ranges);
+    const auto& range_data = source_range_pickler.get_data();
     std::stringstream debug_filename;
     debug_filename << "debug/" << module_name.str() << ".pkl";
     writer_.writeRecord(
-        debug_filename.str(), p.stack().data(), p.stack().size());
+        debug_filename.str(), range_data.data(), range_data.size());
     debug_record->set_key(debug_filename.str());
   }
 
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 980772f..9f3db62 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -9,6 +9,8 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/pickler.h>
 #include <torch/csrc/jit/script/script_type_parser.h>
+#include <torch/csrc/jit/source_range_serialization.h>
+#include <torch/csrc/jit/source_range_serialization_impl.h>
 
 #include "caffe2/core/common.h"
 #include "caffe2/core/types.h"
@@ -328,6 +330,20 @@
     module.register_attribute(
         attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
   }
+
+  // If present, load in the table of source ranges from the original
+  // generating code.
+  std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
+  if (module_def.has_torchscript_debug_arena()) {
+    at::DataPtr data;
+    size_t size;
+    std::tie(data, size) =
+        reader_.getRecord(module_def.torchscript_debug_arena().key());
+
+    gen_ranges =
+        std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
+  }
+
   if (module_def.has_torchscript_arena()) {
     at::DataPtr data;
     size_t size;
@@ -337,7 +353,8 @@
     auto src = std::make_shared<Source>(
         std::string(static_cast<const char*>(data.get()), size),
         module_def.torchscript_arena().key(),
-        1);
+        1,
+        std::move(gen_ranges));
 
     std::function<void(const std::string&)> import_callback =
         [this](const std::string& qualifier) { importCallback(qualifier); };
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 000e32e..8ccb7e6 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -171,10 +171,13 @@
   SourceRangeStack source_range_stack_ = {SourceRange("")};
 
   struct WithSourceRange {
-    explicit WithSourceRange(SourceRangeStack* stack, SourceRange sr)
-        : stack(stack) {
+    explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
       TORCH_INTERNAL_ASSERT(stack);
-      stack->push_back(std::move(sr));
+      if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
+        stack->push_back(std::move(gen_source.value()));
+      } else {
+        stack->push_back(std::move(n->sourceRange()));
+      }
     }
 
     ~WithSourceRange() {
@@ -190,6 +193,13 @@
     TaggedStringStream(TaggedStringStream&& rhs) = default;
 
     TaggedStringStream& operator<<(const std::string& s) {
+      // This prevents having redundant entries at the same offset,
+      // which can happen for example in printValueList when begin
+      // and end are the empty string.
+      if (s.size() == 0) {
+        return *this;
+      }
+
       if (!ranges_.size() || ranges_.back().range != srs_->back()) {
         ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
       }
@@ -233,6 +243,21 @@
       return ranges_;
     }
 
+    // Emit the accumulated text to os and append every recorded range to
+    // source_ranges_out, shifting each record's byte offset by stream_pos
+    // (the position in the destination stream at which this text begins).
+    void print(
+        std::ostream& os,
+        SourceRangeRecords* source_ranges_out,
+        int64_t stream_pos) {
+      os << str();
+      for (const auto& tagged : ranges()) {
+        TaggedRange shifted = tagged;
+        shifted.bytes += stream_pos;
+        source_ranges_out->push_back(std::move(shifted));
+      }
+    }
+
    private:
     std::ostringstream oss_;
     std::vector<TaggedRange> ranges_;
@@ -756,7 +781,7 @@
   }
 
   void printNode(Node* node, bool print_const) {
-    WithSourceRange guard(&source_range_stack_, node->sourceRange());
+    WithSourceRange guard(&source_range_stack_, node);
     // Check for class dependencies. If this node inputs or outputs a class
     // type, we need to add it to our table of dependencies.
     for (const auto input : node->inputs()) {
@@ -1149,8 +1174,7 @@
     Graph& graph = *func.graph();
     used_names_.clear(); // each graph can reuse local names
 
-    WithSourceRange guard(
-        &source_range_stack_, graph.param_node()->sourceRange());
+    WithSourceRange guard(&source_range_stack_, graph.param_node());
 
     indent();
     body_ << "def " << func.name() << "(";
@@ -1250,8 +1274,9 @@
   }
 
   void print(std::ostream& out, SourceRangeRecords& source_ranges_out) {
-    out << getImports() << body_.str();
-    source_ranges_out = body_.ranges();
+    out << getImports();
+    int64_t source_offset = out.tellp();
+    body_.print(out, &source_ranges_out, source_offset);
   }
 };
 
diff --git a/torch/csrc/jit/passes/python_print.h b/torch/csrc/jit/passes/python_print.h
index 9ee0d9f..8bdeaec 100644
--- a/torch/csrc/jit/passes/python_print.h
+++ b/torch/csrc/jit/passes/python_print.h
@@ -12,16 +12,6 @@
 struct Module;
 } // namespace script
 
-// A pair of (byte offset, SourceRange) describing a specific segment
-// of the output stream
-struct TaggedRange {
-  TaggedRange(size_t bytes, SourceRange range)
-      : bytes(bytes), range(std::move(range)) {}
-  size_t bytes;
-  SourceRange range;
-};
-using SourceRangeRecords = std::vector<TaggedRange>;
-
 TORCH_API void PythonPrint(
     std::ostream& out,
     SourceRangeRecords& source_ranges_out,
diff --git a/torch/csrc/jit/source_range.cpp b/torch/csrc/jit/source_range.cpp
index 578e6dd..8a5e573 100644
--- a/torch/csrc/jit/source_range.cpp
+++ b/torch/csrc/jit/source_range.cpp
@@ -1,8 +1,17 @@
 #include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/jit/source_range_serialization.h>
 
 namespace torch {
 namespace jit {
 
+c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
+    const SourceRange& range) {
+  if (gen_ranges_) { // null unless a range table was passed at construction
+    return gen_ranges_->findSourceRangeThatGenerated(range);
+  }
+  return c10::nullopt;
+}
+
 // a range of a shared string 'file_' with
 C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
   const std::string& str = source_->text();
@@ -56,6 +65,13 @@
   out << str.substr(end_line, end_highlight - end_line);
   if (!str.empty() && str.back() != '\n')
     out << "\n";
+  // Retrieve original SourceRange, if present.
+  if (source_) {
+    if (auto orig_source_range = findSourceRangeThatGenerated()) {
+      out << "Compiled from code ";
+      orig_source_range->highlight(out);
+    }
+  }
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h
index 5362457..4176e79 100644
--- a/torch/csrc/jit/source_range.h
+++ b/torch/csrc/jit/source_range.h
@@ -8,6 +8,9 @@
 namespace torch {
 namespace jit {
 
+class SourceRangeUnpickler; // a class; see source_range_serialization.h
+struct SourceRange;
+
 // Source represents a code segment. It keeps track of:
 //  - text : the text of the code segment
 //  - filename (optional) : if present, represents the name of the file from
@@ -15,18 +18,25 @@
 //  - starting_line_no : represents the line in the original file where the
 //                       code segment started.
 struct Source {
-  explicit Source(std::string text)
-      : text_(std::move(text)), filename_(c10::nullopt) {
+  explicit Source(
+      std::string text,
+      std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
+      : text_(std::move(text)),
+        filename_(c10::nullopt),
+        starting_line_no_(0),
+        gen_ranges_(std::move(gen_ranges)) {
     calc_line_start_offsets();
   }
 
   Source(
       std::string text,
       c10::optional<std::string> filename,
-      size_t starting_line_no)
+      size_t starting_line_no,
+      std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
       : text_(std::move(text)),
         filename_(std::move(filename)),
-        starting_line_no_(starting_line_no) {
+        starting_line_no_(starting_line_no),
+        gen_ranges_(std::move(gen_ranges)) {
     calc_line_start_offsets();
   }
 
@@ -67,6 +77,9 @@
     return starting_line_no_;
   }
 
+  c10::optional<SourceRange> findSourceRangeThatGenerated(
+      const SourceRange& range);
+
  private:
   void calc_line_start_offsets() {
     size_t pos = 0;
@@ -82,6 +95,8 @@
   // Starting offsets for lines into the source. e.g. line 0 starts at
   // line_starting_offsets_[0], etc.
   std::vector<size_t> line_starting_offsets_;
+
+  std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
 };
 
 // A SourceRange is a view into a Source, that points to a subset of the source,
@@ -139,6 +154,13 @@
   bool operator!=(const SourceRange& rhs) const {
     return !(*this == rhs);
   }
+
+  c10::optional<SourceRange> findSourceRangeThatGenerated() const {
+    if (source_) { // the lookup is delegated to the backing Source
+      return source_->findSourceRangeThatGenerated(*this);
+    }
+    return c10::nullopt;
+  }
 
  private:
   std::shared_ptr<Source> source_;
@@ -151,5 +173,15 @@
   return out;
 }
 
+// A pair of (byte offset, SourceRange) describing a specific segment
+// of the output stream
+struct TaggedRange {
+  TaggedRange(size_t bytes, SourceRange range)
+      : bytes(bytes), range(std::move(range)) {}
+  size_t bytes;
+  SourceRange range;
+};
+using SourceRangeRecords = std::vector<TaggedRange>;
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/source_range_serialization.cpp b/torch/csrc/jit/source_range_serialization.cpp
new file mode 100644
index 0000000..f42cbce
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization.cpp
@@ -0,0 +1,148 @@
+#include <torch/csrc/jit/source_range_serialization.h>
+#include <torch/csrc/jit/source_range_serialization_impl.h>
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/pickler.h>
+
+namespace torch {
+namespace jit {
+
+class SourceRangeSerializer {
+ public:
+  // Serialize SourceRange as Tuple[SourceType, int, int]
+  // where SourceType = Tuple[str, Optional[str], int],
+  // the serialized form of Source (text, filename, starting line number)
+  c10::IValue serialize(const SourceRange& sr);
+
+ private:
+  // Serialize Source as Tuple[str, Optional[str], int]
+  // This caches serialized sources, since many SourceRanges can
+  // refer to the same one.
+  c10::IValue serialize_source(const std::shared_ptr<Source>& s);
+
+  std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
+};
+
+class SourceRangeDeserializer {
+ public:
+  // Rebuild a SourceRange from its serialized (source, start, end) tuple.
+  SourceRange deserialize(const c10::IValue& iv) {
+    auto fields = iv.toTuple()->elements();
+    TORCH_INTERNAL_ASSERT(fields.size() == 3);
+    std::shared_ptr<Source> source = deserialize_source(fields[0]);
+    const int64_t start = fields[1].toInt();
+    return SourceRange(source, start, fields[2].toInt());
+  }
+
+ private:
+  // Rebuild a Source from (text, filename, starting line), cached per tuple.
+  std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) {
+    auto tup = iv.toTuple();
+    auto cached = cached_sources.find(tup);
+    if (cached != cached_sources.end()) {
+      return cached->second;
+    }
+    auto fields = tup->elements();
+    TORCH_INTERNAL_ASSERT(fields.size() == 3);
+    std::string text = fields[0].toString()->string();
+    c10::optional<std::string> filename =
+        fields[1].toOptional<std::string>();
+    int64_t starting_line_no = fields[2].toInt();
+    auto source = std::make_shared<Source>(
+        std::move(text), std::move(filename), starting_line_no);
+    cached_sources[tup] = source;
+    return source;
+  }
+
+  std::unordered_map<
+      c10::intrusive_ptr<c10::ivalue::Tuple>,
+      std::shared_ptr<Source>>
+      cached_sources;
+};
+
+c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
+  std::vector<c10::IValue> elements = {
+      serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
+  return c10::ivalue::Tuple::create(std::move(elements));
+}
+
+c10::IValue SourceRangeSerializer::serialize_source(
+    const std::shared_ptr<Source>& s) {
+  if (serialized_sources.count(s)) {
+    return serialized_sources.at(s);
+  }
+  std::vector<c10::IValue> elements{
+      s->text(), s->filename(), (int64_t)s->starting_line_no()};
+  auto serialized = c10::ivalue::Tuple::create(std::move(elements));
+  serialized_sources[s] = serialized;
+  return serialized;
+}
+
+SourceRangePickler::SourceRangePickler()
+    : p(new Pickler()), srs(new SourceRangeSerializer()) {}
+
+void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
+  p->start();
+  p->startTuple();
+  for (const auto& range : ranges) {
+    std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
+                                       srs->serialize(range.range)};
+    p->addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
+  }
+  p->endTuple();
+  p->finish();
+}
+
+const std::vector<char>& SourceRangePickler::get_data() {
+  return p->stack();
+}
+
+ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
+    at::DataPtr&& data,
+    size_t size)
+    : data(std::move(data)),
+      size(size),
+      deserializer(new SourceRangeDeserializer()),
+      unpickled_records(nullptr) {}
+
+void ConcreteSourceRangeUnpickler::unpickle() {
+  if (unpickled_records) {
+    return; // already decoded by an earlier call
+  }
+  Unpickler up(data.get(), size, nullptr);
+  auto ivalues = up.parse_ivalue_list();
+  // Publish only a complete table; a throw must not cache a partial one.
+  auto records = std::make_shared<SourceRangeRecords>();
+  for (auto& val : ivalues) {
+    auto tup_elems = val.toTuple()->elements();
+    TORCH_INTERNAL_ASSERT(tup_elems.size() == 2); // (byte offset, range)
+    int64_t offset = tup_elems[0].toInt();
+    records->emplace_back(offset, deserializer->deserialize(tup_elems[1]));
+  }
+  unpickled_records = std::move(records);
+}
+
+c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
+    findSourceRangeThatGenerated(const SourceRange& range) {
+  unpickle(); // records are decoded lazily, on the first lookup
+  // Find the first record whose offset is strictly greater than
+  // range.start(). Comparing offsets directly avoids a throwaway TaggedRange.
+  auto entry = std::upper_bound(
+      unpickled_records->begin(),
+      unpickled_records->end(),
+      range.start(),
+      [](size_t offset, const TaggedRange& rec) -> bool {
+        return offset < rec.bytes;
+      });
+
+  // NB: must decrement iterator since upper_bound finds the element
+  // *greater than* the query. If that is the first element, no record
+  // starts at or before the query, so there is nothing to return.
+  if (entry != unpickled_records->begin()) {
+    return (entry - 1)->range;
+  }
+  return c10::nullopt;
+}
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serialization.h b/torch/csrc/jit/source_range_serialization.h
new file mode 100644
index 0000000..349ef16
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <c10/core/Allocator.h>
+#include <torch/csrc/jit/source_range.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace c10 {
+struct IValue;
+}
+
+namespace torch {
+namespace jit {
+
+class Pickler;
+class SourceRangeSerializer;
+class SourceRangeDeserializer;
+
+class SourceRangePickler {
+ public:
+  SourceRangePickler();
+
+  void pickle(const SourceRangeRecords& ranges);
+
+  const std::vector<char>& get_data();
+
+ private:
+  std::shared_ptr<Pickler> p;
+  std::shared_ptr<SourceRangeSerializer> srs;
+};
+
+class SourceRangeUnpickler {
+ public:
+  virtual c10::optional<SourceRange> findSourceRangeThatGenerated(
+      const SourceRange& range) = 0;
+
+  virtual ~SourceRangeUnpickler() {}
+};
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serialization_impl.h b/torch/csrc/jit/source_range_serialization_impl.h
new file mode 100644
index 0000000..1e1bb10
--- /dev/null
+++ b/torch/csrc/jit/source_range_serialization_impl.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/csrc/jit/pickler.h>
+#include <torch/csrc/jit/source_range_serialization.h>
+
+namespace torch {
+namespace jit {
+
+// The unpickler is reached through a virtual interface, rather than used
+// directly, because of the split between ATen core and torch.
+
+class ConcreteSourceRangeUnpickler : public SourceRangeUnpickler {
+ public:
+  ConcreteSourceRangeUnpickler(at::DataPtr&& data, size_t size);
+
+  c10::optional<SourceRange> findSourceRangeThatGenerated(
+      const SourceRange& range) override;
+
+ private:
+  at::DataPtr data;
+  size_t size;
+
+  void unpickle();
+
+  std::shared_ptr<SourceRangeDeserializer> deserializer;
+  std::shared_ptr<SourceRangeRecords> unpickled_records;
+};
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/source_range_serializer.h b/torch/csrc/jit/source_range_serializer.h
deleted file mode 100644
index 8c10e18..0000000
--- a/torch/csrc/jit/source_range_serializer.h
+++ /dev/null
@@ -1,39 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-#include <torch/csrc/jit/source_range.h>
-
-namespace torch {
-namespace jit {
-
-class SourceRangeSerializer {
- public:
-  // Serialize SourceRange as Tuple[SourceType, int, int]
-  // where SourceType = Tuple[str, Optional[str], int, List[int]],
-  // the serialized form of Source
-  c10::IValue serialize(const SourceRange& sr) {
-    std::vector<c10::IValue> elements = {
-        serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
-    return c10::ivalue::Tuple::create(std::move(elements));
-  }
-
- private:
-  // Serialize Source as Tuple[str, Optional[str], int, List[int]]
-  // This caches serialized sources, since many SourceRanges can
-  // refer to the same one.
-  c10::IValue serialize_source(const std::shared_ptr<Source>& s) {
-    if (serialized_sources.count(s)) {
-      return serialized_sources.at(s);
-    }
-    std::vector<c10::IValue> elements{
-        s->text(), s->filename(), (int64_t)s->starting_line_no()};
-    auto serialized = c10::ivalue::Tuple::create(std::move(elements));
-    serialized_sources[s] = serialized;
-    return serialized;
-  }
-
-  std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
-};
-
-} // namespace jit
-} // namespace torch
\ No newline at end of file