[dynamo] Add missing fields for THPPyInterpreterFrame. (#103227)

Fixes https://github.com/pytorch/pytorch/issues/103210
Test Plan:
Before the fix:
```
pytest test/dynamo/test_export.py -k suppress_errors
```
got result:
```
  File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 511, in _compile
    exception_handler(e, code, frame)
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 216, in exception_handler
    log.error(format_error_msg(e, code, record_filename, frame))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/exc.py", line 248, in format_error_msg
    stack_above_dynamo = filter_stack(extract_stack(frame))
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 231, in extract_stack
    stack = StackSummary.extract(walk_stack(f), limit=limit)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 393, in extract
    return klass._extract_from_extended_frame_gen(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 416, in _extract_from_extended_frame_gen
    for f, (lineno, end_lineno, colno, end_colno) in frame_gen:
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 390, in extended_frame_gen
    for f, lineno in frame_gen:
  File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 334, in walk_stack
    yield f, f.f_lineno
             ^^^^^^^^^^
AttributeError: 'torch._C.dynamo.eval_frame._PyInterpreterFrame' object has no attribute 'f_lineno'
```

After the fix:
```
pytest test/dynamo/test_export.py -k suppress_errors -s
```
Got Result:
```
  File "/data/users/zhxchen17/pytorch/torch/_dynamo/exc.py", line 135, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: map() operator doesn't support scalar or zero-sized tensors during
tracing.

========== The above exception occurred while processing the following code ==========

  File "/data/users/zhxchen17/pytorch/test/dynamo/test_export.py", line 3043, in forward
    def forward(self, xs):
  File "/data/users/zhxchen17/pytorch/test/dynamo/test_export.py", line 3047, in forward
    return map(body, xs)

==========
unimplemented [("map() operator doesn't support scalar or zero-sized tensors during tracing.", 1)]
.

=============================== 1 passed, 133 deselected in 4.60s ================================

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103227
Approved by: https://github.com/williamwen42
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index fe55e0f..efdf7d5 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -3405,6 +3405,25 @@
 
         self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))
 
+    @config.patch(suppress_errors=True)
+    @config.patch(verbose=True)
+    def test_export_with_map_zero_sized_tensor_suppress_errors(self):
+        from functorch.experimental.control_flow import map
+
+        class Module(torch.nn.Module):
+            def forward(self, xs):
+                def body(x):
+                    return x + 1
+
+                return map(body, xs)
+
+        mod = Module()
+        xs = torch.randn(0, 2)
+        with self.assertRaises(
+            torch._dynamo.exc.Unsupported,
+        ):
+            out_graph, _ = torch._dynamo.export(mod, xs)
+
 
 common_utils.instantiate_parametrized_tests(ExportTests)
 
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 9365fc0..da44bdd 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -59,6 +59,24 @@
   return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
 }
 
+static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
+  if (!self->frame->frame_obj) {
+    return PyLong_FromLong(self->frame->f_code->co_firstlineno);
+  }
+  int lineno = PyFrame_GetLineNumber(self->frame->frame_obj);
+  if (lineno < 0) {
+    Py_RETURN_NONE;
+  }
+  return PyLong_FromLong(lineno);
+}
+
+static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
+  if (!self->frame->frame_obj) {
+    Py_RETURN_NONE;
+  }
+  return (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
 static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
     {"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
@@ -69,6 +87,8 @@
     {"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
     {"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
     {"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
+    {"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
+    {"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
     {NULL}};
 
 static PyTypeObject THPPyInterpreterFrameType = {