Create Cache for Fusion Reuse in NVFuser in Python Frontend for Primtorch (#83267)

This PR does the following:

- Replaces the `FusionOwner` with a `FusionCache` and `FusionInterface`.  The `FusionCache` is a singleton that contains a cache of Fusions based on the `FusionDefinition`.  It replaces the TorchScript graph caching that looked up a Fusion based on a stringified and canonicalized representation of the TorchScript graph with a prefix tree of statements in the `FusionDefinition`.  The `FusionInterface` is an object that represents a Fusion in python.  It can also query the cache based on id.
- The ability to print out a mechanically derived definition, in python, for the user to use when debugging was added.
- Replaces the python `examples` directory with true python tests under `test/test_nvfuser_frontend.py`.
- Adds a set of C++ tests under the `test` directory to verify the `FusionCache`, `FusionDefinition`, and parts of the `RecordFunctor` child classes.
- Adds a README file to explain how to use the Python Frontend

While there are 3,000+ line edits, the bulk of the changes were repetitive line changes to the python bindings for each operation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83267
Approved by: https://github.com/jjsjann123, https://github.com/davidberard98
diff --git a/build_variables.bzl b/build_variables.bzl
index 21d9755..bb25892 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -729,6 +729,9 @@
     "torch/csrc/jit/codegen/cuda/partial_split_map.cpp",
     "torch/csrc/jit/codegen/cuda/partition.cpp",
     "torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
+    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.cpp",
+    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp",
+    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp",
     "torch/csrc/jit/codegen/cuda/register_interface.cpp",
     "torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
@@ -891,7 +894,6 @@
     "torch/csrc/autograd/python_variable_indexing.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp",
-    "torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
     "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt
index b8d9a96..13f18c4 100644
--- a/test/cpp/jit/CMakeLists.txt
+++ b/test/cpp/jit/CMakeLists.txt
@@ -96,6 +96,9 @@
 )
 
 if(USE_CUDA)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp)
+  list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
   list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
diff --git a/test/test_nvfuser_frontend.py b/test/test_nvfuser_frontend.py
new file mode 100644
index 0000000..28c5894
--- /dev/null
+++ b/test/test_nvfuser_frontend.py
@@ -0,0 +1,348 @@
+# Owner(s): ["module: nvfuser"]
+
+import unittest
+from typing import List
+
+import torch
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
+from torch.testing._internal.jit_utils import RUN_CUDA
+import torch._refs as refs
+import torch._prims as prims
+
+# Will only create the _nvfuser module if CUDA is available
+if hasattr(torch._C, "_nvfuser"):
+    from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType
+
+RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
+
+def is_pre_volta():
+    if not RUN_NVFUSER:
+        return False
+    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
+    return prop.major < 7
+
+@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
+class TestNvFuserFrontend(TestCase):
+    def test_basic(self) :
+        input1 = torch.ones(2, 4, 8, device='cuda')
+        input2 = torch.ones(2, 4, 8, device='cuda')
+        fc = FusionCache.get()
+        before_fusions = fc.num_fusions()
+
+        fs1 = Fusion()
+        with FusionDefinition(fs1) as fd :
+            t0 = fd.define_tensor(3)
+            t1 = fd.define_tensor(3)
+            c0 = fd.define_constant(3.0)
+
+            t2 = fd.ops.add(t0, t1)
+            t3 = fd.ops.mul(t2, c0)
+            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
+
+            fd.add_output(t4)
+
+        # Expected Output is a tensor of 48's
+        nvf_out1 = fs1.execute([input1, input2])[0]
+
+        # Create a new fusion with the same definition, it should hit the cache!
+        fs2 = Fusion()
+        with FusionDefinition(fs2) as fd :
+            t0 = fd.define_tensor(3)
+            t1 = fd.define_tensor(3)
+            c0 = fd.define_constant(3.0)
+
+            t2 = fd.ops.add(t0, t1)
+            t3 = fd.ops.mul(t2, c0)
+            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
+
+            fd.add_output(t4)
+
+        nvf_out2 = fs2.execute([input1, input2])[0]
+
+        # Check there is still only 1 cache entry
+        fc = FusionCache.get()
+        self.assertEqual(fc.num_fusions() - before_fusions, 1)
+
+        # Create a fusion from a fusion id and make sure it executes!
+        fs3 = Fusion(fs2.id())
+        nvf_out3 = fs3.execute([input1, input2])[0]
+
+        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
+        self.assertEqual(eager_out, nvf_out1)
+        self.assertEqual(eager_out, nvf_out2)
+        self.assertEqual(eager_out, nvf_out3)
+
+    def test_basic_fp16(self) :
+        fs = Fusion()
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor(3, DataType.Half)
+            t1 = fd.define_tensor(3, DataType.Half)
+            c0 = fd.define_constant(3.0)
+
+            t2 = fd.ops.add(t0, t1)
+            t3 = fd.ops.mul(t2, c0)
+            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
+
+            t5 = fd.ops.cast(t4, DataType.Half)
+            fd.add_output(t5)
+
+        input1 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)
+        input2 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)
+
+        # Expected Output is a tensor of 48's
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_cast_double_to_half(self) :
+        fs = Fusion()
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor(2, DataType.Double)
+            t1 = fd.define_tensor(2, DataType.Double)
+
+            t0h = fd.ops.cast(t0, DataType.Half)
+            t1h = fd.ops.cast(t1, DataType.Half)
+            t2 = fd.ops.add(t0h, t1h)
+            t3 = fd.ops.relu(t2)
+            t4 = fd.ops.cast(t3, DataType.Half)
+
+            fd.add_output(t4)
+
+        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float64)
+        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)
+
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = torch.relu(input1.to(torch.half) + input2.to(torch.half))
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_promote_to_double(self) :
+        fs = Fusion()
+
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor(2, DataType.Half)
+            t1 = fd.define_tensor(2, DataType.Double)
+
+            t2 = fd.ops.add(t0, t1)
+            t5 = fd.ops.relu(t2)
+
+            fd.add_output(t5)
+
+        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float16)
+        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)
+
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = torch.relu(input1 + input2)
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_implicit_broadcast_input(self) :
+        fs = Fusion()
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor(1)
+            t1 = fd.define_tensor(3)
+
+            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
+            t2 = fd.ops.add(t0_b, t1)
+
+            fd.add_output(t2)
+
+        input1 = torch.randn(3, device='cuda')
+        input2 = torch.randn(2, 3, 4, device='cuda')
+
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_explicit_broadcast_input(self) :
+        input1 = torch.randn(1, 1, 4, device='cuda')
+        input2 = torch.randn(2, 3, 4, device='cuda')
+
+        fs = Fusion()
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
+            t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())
+
+            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
+            t2 = fd.ops.add(t0_b, t1)
+
+            fd.add_output(t2)
+
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_broadcast_mixing(self) :
+        fs = Fusion()
+        with FusionDefinition(fs) as fd :
+            t0 = fd.define_tensor([3, 1], [1, 1])
+            t1 = fd.define_tensor(1)
+
+            t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0])
+            t2 = fd.ops.add(t0, t1_b)
+
+            fd.add_output(t2)
+
+        input1 = torch.randn(3, 1, device='cuda')
+        input2 = torch.randn(3, device='cuda')
+
+        nvf_out = fs.execute([input1, input2])[0]
+        eager_out = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
+        self.assertEqual(eager_out, nvf_out)
+
+    def test_prim_layer_norm_fwd(self) :
+        def primitive_definition(
+            inputs: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            normalization_axis: int,
+            keepdim: bool,
+        ) -> torch.Tensor:
+            mean = inputs.mean(normalization_axis, keepdim=keepdim)
+            diff = inputs - mean
+            diff_sq = diff * diff
+            var = diff_sq.mean(normalization_axis, keepdim=keepdim)
+            pre_shift_scale_norm_output = (inputs - mean) / torch.sqrt(var + 1e-12)
+            norm_output = weight * pre_shift_scale_norm_output + bias
+            return norm_output
+
+        def nvfuser_fusion(
+            fd: FusionDefinition,
+            normalization_axis: int,
+            norm_size: int,
+            input_shape: List[int],
+            eps: float,
+            keepDim: bool
+        ) -> None :
+            inputs = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
+            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
+            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
+            sum0 = fd.ops.sum(inputs, axes=[normalization_axis], keepdim=keepDim)
+            norm_const = fd.define_constant(norm_size)
+            mean = fd.ops.div(sum0, norm_const)
+            diff = fd.ops.sub(inputs, mean)
+            diff_sq = fd.ops.mul(diff, diff)
+            sum1 = fd.ops.sum(diff_sq, axes=[normalization_axis], keepdim=keepDim)
+            var = fd.ops.div(sum1, norm_const)
+            eps_const = fd.define_constant(eps)
+            var_eps = fd.ops.add(var, eps_const)
+            invstd = fd.ops.rsqrt(var_eps)
+            pre_scale_bias = fd.ops.mul(diff, invstd)
+            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
+            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
+            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
+            out = fd.ops.add(scale, bias_bcast)
+            fd.add_output(out)
+            fd.add_output(mean)
+            fd.add_output(invstd)
+
+        def nvfuser_fusion_var_mean(
+            fd: FusionDefinition,
+            normalization_axis: int,
+            norm_size: int,
+            input_shape: List[int],
+            eps: float,
+            keepDim: bool
+        ) -> None :
+            inputs = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
+            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
+            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
+            var, mean = fd.ops.var_mean(inputs, axes=[normalization_axis], correction=0, keepdim=keepDim)
+            eps_const = fd.define_constant(eps)
+            var_eps = fd.ops.add(var, eps_const)
+            invstd = fd.ops.rsqrt(var_eps)
+            diff = fd.ops.sub(inputs, mean)
+            pre_scale_bias = fd.ops.mul(diff, invstd)
+            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
+            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
+            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
+            out = fd.ops.add(scale, bias_bcast)
+            fd.add_output(out)
+            fd.add_output(mean)
+            fd.add_output(invstd)
+
+        input_size = [64, 128, 1024]
+        dtype = torch.float32
+        device = 'cuda'
+        inputs = torch.randn(*input_size, device=device, requires_grad=True)
+        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
+        biases = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
+        fc = FusionCache.get()
+        before_fusions = fc.num_fusions()
+
+        for _ in range(5) :
+            nvf_fusion = Fusion()
+            with FusionDefinition(nvf_fusion) as fd:
+                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
+            nvf_out = nvf_fusion.execute([inputs, weights, biases])
+
+        for _ in range(5) :
+            nvf_var_mean_fusion = Fusion()
+            with FusionDefinition(nvf_var_mean_fusion) as fd:
+                nvfuser_fusion_var_mean(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
+            nvf_var_mean_out = nvf_var_mean_fusion.execute([inputs, weights, biases])
+
+        for _ in range(5) :
+            eager_out = primitive_definition(inputs, weights, biases, 2, True)
+
+        self.assertEqual(eager_out, nvf_out[0])
+        self.assertEqual(eager_out, nvf_var_mean_out[0])
+        fusion_cache = FusionCache.get()
+        self.assertEqual(fc.num_fusions() - before_fusions, 2)
+
+    def test_prim_rms_norm_fwd(self) :
+        def primitive_definition(
+            inputs: torch.Tensor,
+            weight: torch.Tensor,
+            normalization_axis: int,
+            keepdim: bool,
+        ) -> torch.Tensor:
+            var = inputs.mul(inputs).mean(normalization_axis, keepdim)
+            pre_shift_scale_norm_output = inputs / torch.sqrt(var + 1e-12)
+            norm_output = weight * pre_shift_scale_norm_output
+            return norm_output
+
+        def nvfuser_fusion(
+            fd: FusionDefinition,
+            normalization_axis: int,
+            norm_size: int,
+            input_shape: List[int],
+            eps: float,
+            keepDim: bool
+        ) -> None :
+            inputs = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
+            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
+            inputs_sq = fd.ops.mul(inputs, inputs)
+            sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepDim)
+            norm_const = fd.define_constant(norm_size)
+            var = fd.ops.div(sum0, norm_const)
+            eps_const = fd.define_constant(eps)
+            var_eps = fd.ops.add(var, eps_const)
+            invstd = fd.ops.rsqrt(var_eps)
+            pre_scale = fd.ops.mul(inputs, invstd)
+            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
+            out = fd.ops.mul(pre_scale, weights_bcast)
+            fd.add_output(out)
+            fd.add_output(invstd)
+
+        input_size = [64, 128, 1024]
+        dtype = torch.float32
+        device = 'cuda'
+        inputs = torch.randn(*input_size, device=device, requires_grad=True)
+        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
+        fc = FusionCache.get()
+        before_fusions = fc.num_fusions()
+
+        for _ in range(5) :
+            nvf_fusion = Fusion()
+            with FusionDefinition(nvf_fusion) as fd:
+                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
+            nvf_out = nvf_fusion.execute([inputs, weights])
+
+        for _ in range(5) :
+            eager_out = primitive_definition(inputs, weights, 2, True)
+
+        self.assertEqual(eager_out, nvf_out[0])
+        self.assertEqual(fc.num_fusions() - before_fusions, 1)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.h b/torch/csrc/jit/codegen/cuda/instrumentation.h
index b929fff..ef89fcd 100644
--- a/torch/csrc/jit/codegen/cuda/instrumentation.h
+++ b/torch/csrc/jit/codegen/cuda/instrumentation.h
@@ -31,7 +31,7 @@
 //! An easy way to view traces is to type `about://tracing` in Chrome or
 //! Chromium.
 //!
-class Trace : public NonCopyable {
+class TORCH_CUDA_CU_API Trace : public NonCopyable {
  public:
   using Clock = std::chrono::steady_clock;
 
@@ -73,7 +73,7 @@
 
 //! \internal Automatic scope for a perf marker
 //!   (normally used through the FUSER_PERF_SCOPE macro)
-class TraceScope : public NonCopyable {
+class TORCH_CUDA_CU_API TraceScope : public NonCopyable {
  public:
   explicit TraceScope(const char* event_name) : event_name_(event_name) {
     Trace::instance()->beginEvent(event_name_);
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/README.md b/torch/csrc/jit/codegen/cuda/python_frontend/README.md
new file mode 100644
index 0000000..7f3364e
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/README.md
@@ -0,0 +1,138 @@
+# nvFuser Python Frontend
+
+This frontend allows for a user to describe the set of operations for nvFuser to fuse via 1 or more kernels.  This frontend is intended to be an integration point with PyTorch or standalone applications.
+
+# Usage
+
+## Example 1 - Define and Execute a Fusion
+
+```python
+import torch
+from torch._C._nvfuser import Fusion, FusionDefinition, DataType
+
+fs = Fusion()
+with FusionDefinition(fs) as fd :
+    t0 = fd.define_tensor(symbolic_sizes=[-1, 1, -1],
+                          contiguous=[True, True, True],
+                          dtype=DataType.Float)
+    t1 = fd.define_tensor(3)
+    c0 = fd.define_constant(3.0)
+
+    t2 = fd.ops.add(t0, t1)
+    t3 = fd.ops.mul(t2, c0)
+    t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
+
+    fd.add_output(t4)
+
+input1 = torch.ones(2, 1, 8, device='cuda')
+input2 = torch.ones(2, 4, 8, device='cuda')
+
+nvf_out = fs.execute([input1, input2])[0]
+```
+
+## Example 2 - Lookup and Execute a `Fusion` Based on Id
+
+```python
+fid = 0
+fs = Fusion(fid)
+
+input1 = torch.ones(2, 1, 8, device='cuda')
+input2 = torch.ones(2, 4, 8, device='cuda')
+
+nvf_out = fs.execute([input1, input2])[0]
+```
+
+## Components
+
+### `Fusion` - Represents a Fusion
+#### `Fusion` Methods
+* `defined()`: Allows you to query if the `Fusion` is already defined and can be executed.
+* `execute([inputs])`:  Allows you to execute the currently defined fusion with a list of given inputs and returns a list of tensors.
+* `id()`: Returns the fusion id for a given `Fusion`.
+* `print()`: Prints the low level IR for the currently defined fusion.
+
+### `FusionDefiniton` Context Manager - Interface for Defining Fusions
+
+#### Defining Input Tensors
+_All intermediate tensors are created by operations.  Constant tensors do not exist._
+
+There are 3 ways to define tensors that will be enumerated below.
+
+##### 1.) Defining tensors by the number of input dimensions only
+This interface tells nvFuser that the tensor has a given number of symbolic dimensions that are not necessarily contiguous in memory.  The user also has the ability to specify a data type.  The default type is `Float`.
+```python
+t0 = fd.define_tensor(3)
+t1 = fd.define_tensor(3, DataType.Half)
+```
+
+##### 2.) Defining tensors by a list of concrete sizes and a list of strides
+The `sizes` parameter defines the number of dimensions and the size of each dimension.  The `strides` parameter has to have the same number of dimensions as the `sizes` parameter.
+nvFuser translates the concrete sizes and strides into symbolic sizes and contiguity information that can be directly defined via the next way to define tensors.  This allows the user to directly take a Pytorch defined tensor and query its sizes and strides in order to apply them in the definition.
+```python
+t0 = fd.define_tensor(sizes=[2, 4, 6], strides=[24, 6, 1], dtype=DataType.Half)
+```
+
+##### 3.) Defining tensors by a list of symbolic sizes and a list of contiguity information
+The list of symbolic sizes defines the number of dimensions and `-1` is given for each dimension unless it is a broadcast dimension that is defined with a `1`.  The contiguity information is viewed from right to left.  A `True` definition indicates the current dimension is contiguous with the dimension to its right.
+
+```python
+t0 = fd.define_tensor(symbolic_sizes=[-1, 1, -1], contiguous=[True, True, True], dtype=DataType.Float)
+```
+
+#### Defining Input Scalars
+_All intermediate scalars, except for constants, are created by operations._
+
+The only thing the user has to define for a scalar is its type.
+
+```python
+s0 = fd.define_scalar(dtype=DataType.Half)
+```
+
+#### Defining Constant Scalars
+
+Constants can be of types: `Bool`, `ComplexDouble`, `Double`, or `Int`.  The definition only takes a constant and the type is inferred by the constant.
+
+```python
+c0 = fd.define_constant(3.0)
+```
+
+#### Defining Operations
+
+Operators are added with the following notation:
+```python
+output = fd.ops.foo(arg1, ... )
+```
+You can see a supported list of operations with the following query:
+```python
+python -c "from torch._C._nvfuser import FusionDefinition; help(FusionDefinition.Operators)"
+```
+#### Notating Outputs
+
+The `FusionDefintion` `add_output` method is used to indicate an intermediate is an output to the fusion.
+
+```python
+add_output(output: Tensor)
+# or
+add_output(output: Scalar)
+```
+
+# Debug Information
+**Query a list of supported operations:**
+```python
+python -c "from torch._C._nvfuser import FusionDefinition; help(FusionDefinition.Operators)"
+```
+**View the fusion definitions that are executed by setting an environment variable:**
+```python
+export PYTORCH_NVFUSER_DUMP=python_definition
+```
+Example Output:
+```python
+def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
+    T0 = fd.define_tensor(symbolic_sizes=[-1, 1, -1], contiguous=[True, True, True], dtype=DataType.Float)
+    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[False, False, False], dtype=DataType.Float)
+    S2 = fd.define_constant(3.00000)
+    T3 = fd.ops.add(T0, T1)
+    T4 = fd.ops.mul(T3, S2)
+    T5 = fd.ops.sum(T4, axes=[-1], keepdim=False, dtype=DataType.Float)
+    fd.add_output(T5)
+```
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py
deleted file mode 100644
index b3ce49d..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import torch
-
-from torch._C._nvfuser import Fusion, FusionDefinition, DataType
-
-# Construct and Define Fusion
-fusion = Fusion()
-
-with FusionDefinition(fusion) as fd :
-    t0 = fd.define_tensor(2, DataType.Double)
-    t1 = fd.define_tensor(2, DataType.Double)
-
-    t0h = fd.ops.cast(t0, DataType.Half)
-    t1h = fd.ops.cast(t1, DataType.Half)
-    t2 = fd.ops.add(t0h, t1h)
-    t3 = fd.ops.relu(t2)
-
-    fd.add_output(t3)
-
-fusion.print_ir()
-
-# Execute Fusion
-input1 = torch.ones(2, 4, device='cuda', dtype=torch.float64)
-input2 = torch.ones(2, 4, device='cuda', dtype=torch.float64)
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    outputs = fusion.execute([input1, input2])
-
-print(outputs[0])
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py
deleted file mode 100644
index d5f7070..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import torch
-
-from torch._C._nvfuser import Fusion, FusionDefinition, DataType
-
-# Construct and Define Fusion
-fusion = Fusion()
-
-with FusionDefinition(fusion) as fd :
-    t0 = fd.define_tensor(2, DataType.Half)
-    t1 = fd.define_tensor(2, DataType.Double)
-
-    t2 = fd.ops.add(t0, t1)
-    t5 = fd.ops.relu(t2)
-
-    fd.add_output(t5)
-
-fusion.print_ir()
-
-# Execute Fusion
-input1 = torch.ones(2, 4, device='cuda', dtype=torch.float16)
-input2 = torch.ones(2, 4, device='cuda', dtype=torch.float64)
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    outputs = fusion.execute([input1, input2])
-
-print(outputs[0])
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py
deleted file mode 100644
index 2bd236c..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-from torch._C._nvfuser import Fusion, FusionDefinition, DataType
-
-# Construct and Define Fusion
-fusion = Fusion()
-
-with FusionDefinition(fusion) as fd :
-    t0 = fd.define_tensor(3)
-    t1 = fd.define_tensor(3)
-    s0 = fd.define_scalar()
-
-    c0 = fd.define_constant(3.0)
-
-    t2 = fd.ops.add(t0, t1)
-    t3 = fd.ops.mul(t2, c0)
-    t4 = fd.ops.atan2(t3, s0)
-    t5 = fd.ops.relu(t4)
-    t6 = fd.ops.sum(t5, [-1], False, DataType.Float)
-    t7 = fd.ops.isfinite(t6)
-
-    fd.add_output(t6)
-    fd.add_output(t7)
-
-fusion.print_ir()
-
-# Execute Fusion
-input1 = torch.ones(2, 4, 8, device='cuda')
-input2 = torch.ones(2, 4, 8, device='cuda')
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    outputs = fusion.execute([input1, input2, 2.0])
-
-print(outputs[0])
-print(outputs[1])
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py
deleted file mode 100644
index 06733db..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import torch
-
-from torch._C._nvfuser import Fusion, FusionDefinition
-import torch._prims as prims
-import torch._refs as refs
-
-# Construct and Define Fusion
-fusion1 = Fusion()
-
-with FusionDefinition(fusion1) as fd :
-    t0 = fd.define_tensor(1)
-    t1 = fd.define_tensor(3)
-
-    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
-    t2 = fd.ops.add(t0_b, t1)
-
-    fd.add_output(t2)
-
-fusion1.print_ir()
-
-# Execute Fusion
-input1 = torch.randn(3, device='cuda')
-input2 = torch.randn(2, 3, 4, device='cuda')
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    o = fusion1.execute([input1, input2])[0]
-
-assert(o.shape == torch.Size([2, 3, 4]))
-
-# Reference in prim torch
-ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
-assert(ref_o.allclose(o))
-assert(ref_o.shape == o.shape)
-
-fusion2 = Fusion()
-
-input1 = torch.randn(1, 1, 4, device='cuda')
-input2 = torch.randn(2, 3, 4, device='cuda')
-
-with FusionDefinition(fusion2) as fd :
-    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
-    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())
-
-    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
-    t2 = fd.ops.add(t0_b, t1)
-
-    fd.add_output(t2)
-
-fusion2.print_ir()
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    o = fusion2.execute([input1, input2])[0]
-
-assert(o.shape == torch.Size([2, 3, 4]))
-
-# Reference in prim torch
-ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
-assert(ref_o.allclose(o))
-assert(ref_o.shape == o.shape)
-
-# Construct and Define Fusion
-fusion3 = Fusion()
-
-with FusionDefinition(fusion3) as fd :
-    # t0 = fd.define_tensor(2)
-    t0 = fd.define_tensor([3, 1], [1, 1])
-    t1 = fd.define_tensor(1)
-
-    t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0])  # 1 -> 0
-    t2 = fd.ops.add(t0, t1_b)
-
-    fd.add_output(t2)
-
-fusion3.print_ir()
-
-# Execute Fusion
-input1 = torch.randn(3, 1, device='cuda')
-input2 = torch.randn(3, device='cuda')
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    o = fusion3.execute([input1, input2])[0]
-
-assert(o.shape == torch.Size([3, 3]))
-
-# Reference in prim torch
-ref_o = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
-assert(ref_o.allclose(o))
-assert(ref_o.shape == o.shape)
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py
deleted file mode 100644
index 55fc258..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import torch
-
-from torch._C._nvfuser import Fusion, FusionDefinition, DataType
-
-# Construct and Define Fusion
-fusion = Fusion()
-
-with FusionDefinition(fusion) as fd :
-    t0 = fd.define_tensor(3, DataType.Half)
-    t1 = fd.define_tensor(1, DataType.Half)
-    s0 = fd.define_scalar()
-
-    c0 = fd.define_constant(3.0)
-
-    t2 = fd.ops.add(t0, t1)
-    t3 = fd.ops.mul(t2, c0)
-    t4 = fd.ops.mul(t3, s0)
-    t5 = fd.ops.relu(t4)
-    t6 = fd.ops.sum(t5, [-1], False, DataType.Float)
-
-    t7 = fd.ops.cast(t6, DataType.Half)
-    fd.add_output(t7)
-
-fusion.print_ir()
-
-# Execute Fusion
-input1 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)
-input2 = torch.ones(8, device='cuda', dtype=torch.float16)
-
-# Kernel compilation should be cached for the 2nd iteration
-# with input tensors of the same shape
-for _ in range(5) :
-    outputs = fusion.execute([input1, input2, 2.0])
-
-print(outputs[0])
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.cpp
new file mode 100644
index 0000000..46a91c7
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.cpp
@@ -0,0 +1,142 @@
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
+#include <mutex>
+
+namespace nvfuser {
+
+static std::mutex fusion_cache_lock;
+FusionCache* FusionCache::singleton_ = nullptr;
+
+FusionCacheEntry::FusionCacheEntry(RecordFunctor* rec, size_t _fusion_id)
+    : record(rec), record_hash_map(), fusion_id(_fusion_id), visits(0) {}
+
+bool FusionCacheEntry::isTerminal() const {
+  return (record.get()->recordType() == RecordType::End);
+}
+
+FusionCache* FusionCache::get(size_t max_fusions) {
+  std::lock_guard<std::mutex> guard(fusion_cache_lock);
+  if (singleton_ == nullptr) {
+    singleton_ = new FusionCache(max_fusions);
+  }
+  return singleton_;
+}
+
+size_t FusionCache::numFusions() const {
+  return fusions_.size();
+}
+
+void FusionCache::print(std::ostream& os) {
+  os << "Total Fusions: " << fusions_.size() << "\n";
+
+  // Does not make sense to print stats if the cache is disabled.
+  if (fusions_.size() > 0) {
+    os << "Cache Hits by Fusion Id:\n";
+    auto total_cache_hits = 0;
+    for (size_t i = 0; i < terminal_cache_entries_.size(); ++i) {
+      // The first visit is a miss!
+      auto visits = terminal_cache_entries_[i]->visits - 1;
+      total_cache_hits += visits;
+      os << "\t" << i << " -> " << visits << " hits\n";
+    }
+
+    auto hit_rate = static_cast<float>(total_cache_hits) /
+        static_cast<float>(fusion_cache_start_->visits) * 100.0;
+    os << "Cache Lookups: " << fusion_cache_start_->visits;
+    os << " Cache Hits: " << total_cache_hits;
+    os << " Hit Rate: " << hit_rate << "%\n";
+  }
+}
+
+FusionCache::FusionCache(size_t max_fusions)
+    : max_fusions_(max_fusions),
+      fusion_cache_start_(nullptr),
+      fusion_cache_ptr_(nullptr),
+      fusions_() {
+  RecordFunctor* start = new StartRecord();
+  fusion_cache_start_ = std::make_unique<FusionCacheEntry>(start);
+  fusion_cache_ptr_ = fusion_cache_start_.get();
+}
+
+c10::optional<FusionCacheEntry*> FusionCache::lookupFusionCacheEntry(
+    RecordFunctor* rec) const {
+  TORCH_CHECK(
+      !fusionCachePtr()->isTerminal(),
+      "There should be no children from a Terminal Cache Entry!");
+  TORCH_CHECK(rec, "Record is null!");
+  auto cache_entry = fusionCachePtr()->record_hash_map.find(rec);
+  if (cache_entry == std::end(fusionCachePtr()->record_hash_map)) {
+    return c10::nullopt;
+  } else {
+    return c10::optional<FusionCacheEntry*>(cache_entry->second.get());
+  }
+}
+
+c10::optional<size_t> FusionCache::createFusionCacheEntry(RecordFunctor* rec) {
+  c10::optional<size_t> result = c10::nullopt;
+  TORCH_CHECK(
+      !fusionCachePtr()->isTerminal(),
+      "Cannot create a cache entry from a terminal entry!");
+  TORCH_CHECK(rec, "Record is null!");
+
+  size_t fusion_id = 0;
+  if (rec->recordType() == RecordType::End) {
+    TORCH_CHECK(
+        (fusions_.size() + 1) <= max_fusions_,
+        "The number of fusions in nvfuser has exceeded ",
+        max_fusions_,
+        "fusions.  The max_fusions for the FusionCache might need to be ",
+        "increased if the max number is not being exceeded due to an error.");
+    fusions_.push_back(std::make_unique<Nvf::FusionExecutorCache>(
+        std::make_unique<Nvf::Fusion>()));
+    fusion_id = fusions_.size() - 1;
+    result = c10::optional<size_t>(fusion_id);
+  }
+
+  // Copying the record owned by the FusionDefinition that calls this function
+  // so the cache owns a copy when the FusionDefinition gets destroyed rather
+  // than managing a shared pointer that would  only share with
+  // FusionDefinition that creates a cache entry but not cache lookups
+  RecordFunctor* new_rec = rec->clone();
+  fusionCachePtr()->record_hash_map[new_rec] =
+      std::make_unique<FusionCacheEntry>(new_rec, fusion_id);
+  if (rec->recordType() == RecordType::End) {
+    terminal_cache_entries_.push_back(
+        fusionCachePtr()->record_hash_map[new_rec].get());
+  }
+  if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
+    std::stringstream ss;
+    new_rec->print(ss);
+    std::cout << "\nFusionDefinition: Create new cache entry for: " << ss.str()
+              << "\n";
+  }
+  return result;
+}
+
+void FusionCache::resetFusionCachePtr() {
+  fusion_cache_ptr_ = fusion_cache_start_.get();
+  TORCH_CHECK(fusionCachePtr()->record->recordType() == RecordType::Start);
+  ++(fusionCachePtr()->visits);
+}
+
+void FusionCache::traverseFusionCache(RecordFunctor* rec) {
+  TORCH_CHECK(
+      !fusionCachePtr()->isTerminal(),
+      "Cannot traverse cache from a terminal entry!");
+  auto cache_entry = fusionCachePtr()->record_hash_map.find(rec);
+  TORCH_CHECK(
+      cache_entry != std::end(fusionCachePtr()->record_hash_map),
+      "Cache Entry for Cache Traverse is not found!");
+  TORCH_CHECK(cache_entry->second, "Record in Cache Entry is null!");
+  fusion_cache_ptr_ = cache_entry->second.get();
+  ++(fusionCachePtr()->visits);
+}
+
+FusionCacheEntry* FusionCache::fusionCachePtr() const {
+  TORCH_INTERNAL_ASSERT(
+      fusion_cache_ptr_ != nullptr,
+      "The fusion cache entry is unexpectedly null.");
+  return fusion_cache_ptr_;
+}
+
+} // namespace nvfuser
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h
new file mode 100644
index 0000000..30cc7fa
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h
@@ -0,0 +1,109 @@
+#pragma once
+#include <c10/macros/Export.h>
+
+#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
+
+#include <memory>
+
+//! nvFuser Fusion IR namespace abbreviation
+namespace Nvf = torch::jit::fuser::cuda;
+
+namespace nvfuser {
+
+struct RecordFunctor;
+
+//! \struct FusionCacheEntry
+//! \brief Is the container for a Node in the cache contained in the
+//! FusionCache that is organized as a prefix tree.
+
+struct TORCH_CUDA_CU_API FusionCacheEntry {
+  FusionCacheEntry(RecordFunctor* rec, size_t _fusion_id = 0);
+
+  // Queries whether the entry denotes a leaf node which also represents
+  // a the end of Fusion entry in the cache.
+  bool isTerminal() const;
+
+  //! An entry's primary data is the record it holds
+  std::unique_ptr<RecordFunctor> record;
+  //! A hash map of the children for the current node.
+  //! The hash map hashs a pointer to a RecordFunctor because
+  //! the hash function is virtual.
+  std::unordered_map<RecordFunctor*, std::unique_ptr<FusionCacheEntry>>
+      record_hash_map;
+  //! An index into FusionCache's vector of nvFuser object that holds an
+  //! unscheduled Fusion.  The id is only valid if the entry is terminal.
+  size_t fusion_id;
+  //! Count of times the Entry is traversed
+  size_t visits;
+};
+
+//! \class FusionCache
+//! \brief A singleton class used in the nvFuser python interface
+//! to manage the caching of fusions.
+//!
+//! The fusion cache implements a prefix tree of records in order to cache
+//! fusions.  A leaf of the tree with a terminal node contains an nvFuser
+//! Fusion IR container for a cached instance.
+//!
+//! \todo Add the ability to evict a fusion.  There is currently a max number
+//! of fusions that is checked to prevent a runaway case.
+
+class TORCH_CUDA_CU_API FusionCache {
+  //! The constructor is private given the FusionCache is only constructed
+  //! as a singleton.
+  FusionCache(size_t max_fusions);
+
+  //! Copy and Assignment of the FusionCache is not supported
+  FusionCache(const FusionCache&) = delete;
+  FusionCache& operator=(const FusionCache&) = delete;
+
+ public:
+  //! The next 2 pubic methods are the python interface methods
+
+  //! Gets a pointer to the singleton and creates a new one if necessary
+  static FusionCache* get(size_t max_fusions = 8192);
+  //! Number of fusions cached
+  size_t numFusions() const;
+  //! print cache stats
+  void print(std::ostream& os);
+
+  //! The rest of the public methods are only used in C++
+
+  //! Queries the current cache entry to see if a record matches one of its
+  //! children
+  c10::optional<FusionCacheEntry*> lookupFusionCacheEntry(
+      RecordFunctor* rec) const;
+  //! Creates a child node for the current cache entry and an optional
+  //! fusion_id is returned if the new entry is terminal
+  c10::optional<size_t> createFusionCacheEntry(RecordFunctor* rec);
+  //! Resets the current cache pointer to the top of the tree
+  void resetFusionCachePtr();
+  //! Traverses the cache from the current entry to the child associated
+  //! with the record given.
+  void traverseFusionCache(RecordFunctor* rec);
+
+  friend class FusionInterface;
+
+ private:
+  //! Returns the pointer to the current cache entry
+  FusionCacheEntry* fusionCachePtr() const;
+
+  //! The static pointer to the FusionCache
+  static FusionCache* singleton_;
+
+  //! The max allowed number of fusions in the cache
+  size_t max_fusions_;
+  //! The top of the prefix tree used to start a cache look up of a given
+  //! fusion definition.
+  std::unique_ptr<FusionCacheEntry> fusion_cache_start_;
+  //! A pointer to the current cache entry in a cache lookup of a fusion
+  //! definition.
+  FusionCacheEntry* fusion_cache_ptr_;
+  //! A vector of nvFuser Fusion IR fusions.
+  std::vector<std::unique_ptr<Nvf::FusionExecutorCache>> fusions_;
+  //! A vector of Terminal Cache Entries for Stats collection
+  std::vector<FusionCacheEntry*> terminal_cache_entries_;
+};
+
+} // namespace nvfuser
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp
index 4efdc21..cf467d9 100644
--- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp
@@ -1,65 +1,186 @@
-#ifdef USE_CUDA
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
 #include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
-#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h>
+#include <torch/csrc/jit/codegen/cuda/utils.h>
+
+// Require namespace for perf scope instrumentation
+using namespace torch::jit::fuser::cuda::inst;
 
 namespace nvfuser {
 
-FusionDefinition::FusionDefinition(FusionOwner* fusion_owner)
-    : fusion_owner_(fusion_owner),
-      prev_fusion_(nullptr),
+const char* dtypeToPyString(Nvf::DataType t) {
+  switch (t) {
+    case Nvf::DataType::Bool:
+      return "DataType.Bool";
+    case Nvf::DataType::Double:
+      return "DataType.Double";
+    case Nvf::DataType::Float:
+      return "DataType.Float";
+    case Nvf::DataType::Half:
+      return "DataType.Half";
+    case Nvf::DataType::BFloat16:
+      return "DataType.Bfloat16";
+    case Nvf::DataType::Int:
+      return "DataType.Int";
+    case Nvf::DataType::Int32:
+      return "DataType.Int32";
+    case Nvf::DataType::ComplexFloat:
+      return "DataType.ComplexFloat";
+    case Nvf::DataType::ComplexDouble:
+      return "DataType.ComplexDouble";
+    case Nvf::DataType::Null:
+      return "DataType.Null";
+    default:
+      break;
+  }
+  TORCH_INTERNAL_ASSERT(false, "No string found for data type.");
+  return nullptr;
+}
+
+FusionDefinition::FusionDefinition(FusionInterface* fusion, size_t max_length)
+    : max_length_(max_length),
+      fusion_(fusion),
+      fusion_cache_(FusionCache::get()),
+      end_record_(new EndRecord()),
       recording_(),
       recording_state_(),
       fusion_state_(),
       ops(this) {}
 
-FusionDefinition* FusionDefinition::enter() {
-  prev_fusion_ = FusionGuard::getCurFusion();
-  FusionGuard::setCurFusion(fusionPtr());
-  return this;
-}
-void FusionDefinition::exit() {
+void FusionDefinition::buildFusionIr() {
+  FUSER_PERF_SCOPE("FusionDefinition::buildFusionIr");
+  auto fusion_guard = fusionInterfacePtr()->guard();
   fusion_state_.resize(recording_state_.size(), nullptr);
   for (auto& record : recording_) {
     auto functor = record.get();
     (*functor)(*this);
   }
-
-  FusionGuard::setCurFusion(prev_fusion_);
-  prev_fusion_ = nullptr;
 }
 
-Scalar* FusionDefinition::defineScalar() {
-  Scalar* out = new nvfuser::Scalar(recording_state_.size());
-  recording_state_.emplace_back(out);
+FusionCache* FusionDefinition::fusionCachePtr() const {
+  TORCH_INTERNAL_ASSERT(
+      fusion_cache_ != nullptr, "FusionCache pointer is null!");
+  return fusion_cache_;
+}
+
+FusionInterface* FusionDefinition::fusionInterfacePtr() const {
+  TORCH_INTERNAL_ASSERT(fusion_ != nullptr, "FusionInterface pointer is null!");
+  return fusion_;
+}
+
+FusionDefinition* FusionDefinition::enter() {
+  TORCH_CHECK(max_length_ > 0, "Can't make a FusionDefinition with 0 records!");
+  TORCH_CHECK(
+      !fusionInterfacePtr()->defined(), "Fusion Interface is already defined!");
+  fusionCachePtr()->resetFusionCachePtr();
+  return this;
+}
+
+void FusionDefinition::exit() {
+  FUSER_PERF_SCOPE("FusionDefinition::exit");
+  auto cache_entry =
+      fusionCachePtr()->lookupFusionCacheEntry(end_record_.get());
+  if (!cache_entry.has_value()) {
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
+      std::cout << "\nFusionDefinition: Terminal Node not found.\n";
+    }
+    auto fusion_id =
+        fusionCachePtr()->createFusionCacheEntry(end_record_.get());
+    TORCH_CHECK(fusion_id.has_value(), "Invalid fusion id!");
+    fusionInterfacePtr()->define(fusion_id.value());
+    fusionCachePtr()->traverseFusionCache(end_record_.get());
+
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonDefinition)) {
+      print(std::cout);
+    }
+
+    buildFusionIr();
+
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::FusionIrPresched)) {
+      fusionInterfacePtr()->print();
+    }
+  } else {
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
+      std::cout << "\nFusionDefinition: Terminal Node found!\n";
+    }
+    fusionInterfacePtr()->define(cache_entry.value()->fusion_id);
+    fusionCachePtr()->traverseFusionCache(end_record_.get());
+  }
+}
+
+void FusionDefinition::print(std::ostream& os) const {
+  os << "\ndef nvfuser_fusion_id" << fusion_->id();
+  os << "(fd : FusionDefinition) -> None :\n";
+  os << std::dec;
+  for (auto& rec : recording_) {
+    os << "    ";
+    rec->print(os);
+    os << "\n";
+  }
+  os << "\n";
+}
+
+Scalar FusionDefinition::defineScalar() {
+  FUSER_PERF_SCOPE("FusionDefinition::defineScalar");
+  Scalar out(recording_state_.size());
+  recording_state_.emplace_back(out(), StateType::Scalar);
   return out;
 }
-Tensor* FusionDefinition::defineTensor() {
-  Tensor* out = new nvfuser::Tensor(recording_state_.size());
-  recording_state_.emplace_back(out);
+
+Tensor FusionDefinition::defineTensor() {
+  FUSER_PERF_SCOPE("FusionDefinition::defineTensor");
+  Tensor out(recording_state_.size());
+  recording_state_.emplace_back(out(), StateType::Tensor);
   return out;
 }
+
 void FusionDefinition::defineRecord(RecordFunctor* record) {
+  FUSER_PERF_SCOPE("FusionDefinition::defineRecord");
+  TORCH_CHECK(
+      (recording_.size() + 1) <= max_length_,
+      "The fusion definition has exceeded ",
+      max_length_,
+      "operations.  The max_length for FusionDefintion's might need to be ",
+      "increased if the definition is created as expected.");
   recording_.emplace_back(record);
+  auto cache_entry =
+      fusionCachePtr()->lookupFusionCacheEntry(recording_.back().get());
+  // If the Record is found in the cache, the FusionDefinition and the Cache
+  // will not share Record given the Record had to be created in order to
+  // match it but it also already existed in the cache.
+  if (cache_entry.has_value()) {
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
+      std::cout << "\nFusionDefinition: Record (hash: 0x" << std::hex
+                << record->hash() << ") hit in Fusion Cache.\n";
+    }
+    // The FusionDefinition and the Cache will share the Record
+  } else {
+    if (Nvf::isDebugDumpEnabled(Nvf::DebugDumpOption::PythonFrontendDebug)) {
+      std::cout << "\nFusionDefinition: Record (hash: 0x" << std::hex
+                << record->hash() << ") missed in Fusion Cache.\n";
+    }
+    fusionCachePtr()->createFusionCacheEntry(recording_.back().get());
+  }
+  fusionCachePtr()->traverseFusionCache(recording_.back().get());
 }
 
-void FusionDefinition::addInput(NvfVal* input) {
-  fusionPtr()->addInput(input);
+void FusionDefinition::addInput(Nvf::Val* input) {
+  fusionInterfacePtr()->addInput(input);
 }
-void FusionDefinition::addOutput(NvfVal* output) {
-  fusionPtr()->addOutput(output);
+void FusionDefinition::addOutput(Nvf::Val* output) {
+  fusionInterfacePtr()->addOutput(output);
 }
 
-NvfVal* FusionDefinition::getFusionState(size_t index) const {
+Nvf::Val* FusionDefinition::getFusionState(size_t index) const {
   return fusion_state_.at(index);
 }
-void FusionDefinition::setFusionState(size_t index, NvfVal* val) {
+void FusionDefinition::setFusionState(size_t index, Nvf::Val* val) {
   fusion_state_.at(index) = val;
 }
 
-Fusion* FusionDefinition::fusionPtr() {
-  return fusion_owner_->fusionPtr();
+State FusionDefinition::recordingState(size_t index) const {
+  return recording_state_.at(index);
 }
 
 } // namespace nvfuser
-
-#endif // USE_CUDA
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h
index a5aca2f..6872381 100644
--- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h
@@ -1,21 +1,24 @@
 #pragma once
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h>
+#include <c10/macros/Export.h>
 
-//! nvFuser Fusion IR Types
-using NvfDataType = torch::jit::fuser::cuda::DataType;
-using NvfFusion = torch::jit::fuser::cuda::Fusion;
-using NvfTensorView = torch::jit::fuser::cuda::TensorView;
-using NvfVal = torch::jit::fuser::cuda::Val;
+#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+
+//! nvFuser Fusion IR namespace abbreviation
+namespace Nvf = torch::jit::fuser::cuda;
 
 namespace nvfuser {
 
+class FusionCache;
+class FusionInterface;
 struct RecordFunctor;
 
-//! The State, child classes Tensor and Scalar, and the StateType enum
-//! are used to define state objects to encapsulate the recording of state
-//! in the FusionDefinition.
+//! This is helper function used to print a python formated
+//! Fusion IR DataType when printing a fusion definition.
+
+TORCH_CUDA_CU_API const char* dtypeToPyString(Nvf::DataType t);
+
+//! The State and the StateType enum are used to define state objects to
+//! encapsulate the recording of state in the FusionDefinition.
 
 enum class StateType {
   Tensor,
@@ -24,15 +27,15 @@
 };
 
 struct State {
-  State(StateType _stype, size_t _index) : stype(_stype), index(_index) {}
+  State(size_t _index, StateType _stype) : index(_index), stype(_stype) {}
 
-  //! StateType is either: Tensor or Scalar
-  StateType stype;
   //! A unique index to identifiy each recorded state item.
   size_t index;
+  //! StateType is either: Tensor or Scalar
+  StateType stype;
 };
 
-//! The child classes are used to define separate function signtures in
+//! The Tensor and Scalar classes are used to define separate function signtures
 //! in the FusionDefintion to identify the appropriate Operator function.
 //!
 //! Example:
@@ -40,12 +43,26 @@
 //!   add(Tensor* arg1, Tensor* arg2) -> Tensor*
 //!   add(Tensor* arg1, Scalar* arg2) -> Tensor*
 //!   add(Scalar* arg1, Scalar* arg2) -> Scalar*
-struct Tensor : State {
-  Tensor(size_t _index) : State(StateType::Tensor, _index) {}
+struct Tensor {
+  Tensor(size_t _index) : index(_index) {}
+
+  size_t operator()() const {
+    return index;
+  }
+
+  //! A unique index to identifiy each recorded state item.
+  size_t index;
 };
 
-struct Scalar : State {
-  Scalar(size_t _index) : State(StateType::Scalar, _index) {}
+struct Scalar {
+  Scalar(size_t _index) : index(_index) {}
+
+  size_t operator()() const {
+    return index;
+  }
+
+  //! A unique index to identifiy each recorded state item.
+  size_t index;
 };
 
 //! FusionDefinition defines the C++ side of a Python Context manager to
@@ -56,17 +73,14 @@
 //! in a cache and the recorded records are used to build an nvFuser Fusion
 //! object if the definition missed in the cache.
 //!
-//! \todo Need to implement the cache portion. Currently, the Fusion object
-//! is always built.
-//!
 //! The nested Operators class was designed to allow the user to query all the
 //! available Operators in the FusionDefinition via python help.
 //!
 //! Example:
 //!   help(FusionDefinition.Operators)
-class FusionDefinition {
+class TORCH_CUDA_CU_API FusionDefinition {
  public:
-  FusionDefinition(FusionOwner* fusion_owner);
+  FusionDefinition(FusionInterface* fusion, size_t max_length = 256);
 
   // The copy/move/assign constructors/operators are being removed
   // because it is not possible to copy the fusion_recording data member
@@ -81,46 +95,60 @@
   FusionDefinition* enter();
   //! Exit Python Context Manager -- Triggers cache lookup
   void exit();
+  //! Prints a python function representing the definition
+  void print(std::ostream& os) const;
 
   //! These methods are used to record the FusionDefinition for cache lookup
 
   //! Defines a Scalar State Record
-  Scalar* defineScalar();
+  Scalar defineScalar();
   //! Defines a Tensor State Record
-  Tensor* defineTensor();
+  Tensor defineTensor();
   //! Defines a Record that records the operation required to
   //! build the corresponding Fusion IR operation on cache miss.
   void defineRecord(RecordFunctor* record);
-
-  //! These methods are used to replay the operations for building the
-  //! nvFuser Fusion IR on a cache miss.
-
   //! Adds a Tensor/Scalar input to the Fusion object
-  void addInput(NvfVal* input);
+  void addInput(Nvf::Val* input);
   //! Adds a Tensor/Scalar output to the Fusion object
-  void addOutput(NvfVal* output);
+  void addOutput(Nvf::Val* output);
   //! Gets a Fusion IR Tensor/Scalar object
-  NvfVal* getFusionState(size_t index) const;
+  Nvf::Val* getFusionState(size_t index) const;
   //! Sets a Fusion IR Tensor/Scalar object
-  void setFusionState(size_t index, NvfVal* val);
-
-  //! A pointer to the nvFuser Fusion IR Oject
-  NvfFusion* fusionPtr();
+  void setFusionState(size_t index, Nvf::Val* val);
+  //! Gets a Record State object
+  State recordingState(size_t index) const;
 
  private:
-  // \todo These items will be replaced by a FusionManager instead of a cache
-  // for an individual fusion object
-  FusionOwner* fusion_owner_;
-  NvfFusion* prev_fusion_;
+  //! Builds an nvFuser Fusion IR object upon exit of a FusionDefintion
+  //! when a cache lookup fails.
+  void buildFusionIr();
+  //! Returns the FusionCache Ptr that holds the cache of Fusions
+  FusionCache* fusionCachePtr() const;
+  //! Returns the FusionInterface Ptr that represents the corresponding
+  //! Fusion IR object.
+  FusionInterface* fusionInterfacePtr() const;
+
+  //! Holds the defined maximum length of a FusionDefinition in order to
+  //! prevent a run away error. The user should feel free to increase this
+  //! number as appropriate.
+  size_t max_length_;
+
+  //! A pointer to an interface for an nvFusion Fusion IR object.
+  FusionInterface* fusion_;
+  //! A pointer to the FusionCache.
+  FusionCache* fusion_cache_;
+
+  //! Holds an End Record
+  std::unique_ptr<RecordFunctor> end_record_;
 
   //! A vector of record operations in the FusionDefintion
   std::vector<std::unique_ptr<RecordFunctor>> recording_;
-  //! A vector of state (Tensor/Scalar) recorded in the FusionDefinition
-  std::vector<std::unique_ptr<State>> recording_state_;
+  //! A vector of state recorded in the FusionDefinition
+  std::vector<State> recording_state_;
 
   //! A vector of nvFuser Fusion IR TensorViews/Vals for building the Fusion
   //! IR graph.
-  std::vector<NvfVal*> fusion_state_;
+  std::vector<Nvf::Val*> fusion_state_;
 
  public:
   //! The Operators are not directly defined in this header.  They are defined
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp
new file mode 100644
index 0000000..d1d33dd
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.cpp
@@ -0,0 +1,60 @@
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h>
+
+namespace nvfuser {
+
+FusionInterface::FusionInterface() : fusion_id_(c10::nullopt) {}
+FusionInterface::FusionInterface(size_t fusion_id)
+    : fusion_id_(c10::optional<size_t>(fusion_id)) {}
+
+void FusionInterface::define(size_t fusion_id) {
+  auto fc = FusionCache::get();
+  TORCH_CHECK(fusion_id < fc->fusions_.size(), "Invalid fusion id!");
+  fusion_id_ = c10::optional<size_t>(fusion_id);
+}
+
+bool FusionInterface::defined() const {
+  return fusion_id_.has_value();
+}
+
+size_t FusionInterface::id() const {
+  TORCH_CHECK(defined(), "Invalid fusion id!");
+  return fusion_id_.value();
+}
+
+void FusionInterface::addInput(Nvf::Val* input) const {
+  fusionPtr()->addInput(input);
+}
+
+void FusionInterface::addOutput(Nvf::Val* output) const {
+  fusionPtr()->addOutput(output);
+}
+
+std::vector<at::Tensor> FusionInterface::execute(
+    const at::ArrayRef<c10::IValue>& inputs) const {
+  return fusionExecutorCachePtr()->runFusionWithInputs(inputs);
+}
+
+Nvf::FusionGuard FusionInterface::guard() const {
+  return Nvf::FusionGuard(fusionPtr());
+}
+
+void FusionInterface::print() const {
+  fusionExecutorCachePtr()->printFusion();
+}
+
+Nvf::FusionExecutorCache* FusionInterface::fusionExecutorCachePtr() const {
+  auto fc = FusionCache::get();
+  TORCH_CHECK(defined(), "Invalid fusion id!");
+  TORCH_CHECK(
+      fc->fusions_.at(fusion_id_.value()), "FusionExecutorCache Ptr is Null!");
+  return fc->fusions_.at(fusion_id_.value()).get();
+}
+
+Nvf::Fusion* FusionInterface::fusionPtr() const {
+  auto fusion_ptr = fusionExecutorCachePtr()->fusion();
+  TORCH_CHECK(fusion_ptr != nullptr, "Fusion IR pointer is null!");
+  return fusion_ptr;
+}
+
+} // namespace nvfuser
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h
new file mode 100644
index 0000000..60d55f1
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h
@@ -0,0 +1,72 @@
+#pragma once
+#include <c10/macros/Export.h>
+
+#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+
+//! nvFuser Fusion IR namespace abbreviation
+namespace Nvf = torch::jit::fuser::cuda;
+
+namespace nvfuser {
+
+//! \class FusionInterface
+//! \brief Implements an Interface that represents an nvFuser IR object in
+//! in python.
+//!
+//! Example 1 - Define fusion:
+//!
+//!   fs = Fusion()
+//!   with FusionDefinition(fs) as fd :
+//!       t0 = fd.define_tensor(3)
+//!       s1 = fd.define_constant(3.)
+//!       t2 = fd.ops.add(t0, s1)
+//!       fd.add_output(t2)
+//!
+//!   input = torch.ones(2, 4, 8, device='cuda')
+//!   for _ in range(5) :
+//!      outputs = fs.execute([input])
+//!
+//! Example 2 - Use cached fusion, directly, based on id:
+//!
+//!   fs = Fusion(fusion_id)
+//!
+//!   input = torch.ones(2, 4, 8, device='cuda')
+//!   for _ in range(5) :
+//!      outputs = fs.execute([input])
+
+class TORCH_CUDA_CU_API FusionInterface {
+ public:
+  //! Pybind11 cannot bind to c10::optional and Pytorch is compiled with C++14.
+  //! Therefore, I am adding two constructors, instead.
+  FusionInterface();
+  FusionInterface(size_t fusion_id);
+
+  //! Define which Fusion IR object the interface represents
+  void define(size_t fusion_id);
+  //! Query whether the Fusion IR is defined
+  bool defined() const;
+  //! Return fusion id of this Fusion
+  size_t id() const;
+
+  //! Adds an input to the represented Fusion IR.
+  void addInput(Nvf::Val* input) const;
+  //! Adds an Output to the represented Fusion IR.
+  void addOutput(Nvf::Val* output) const;
+  //! Executes a fusion if the current cache pointer points at a terminal node
+  std::vector<at::Tensor> execute(
+      const at::ArrayRef<c10::IValue>& inputs) const;
+  //! Activates a guard around the represented Fusion IR.
+  Nvf::FusionGuard guard() const;
+  //! Prints the represented nvFuser IR
+  void print() const;
+
+ private:
+  //! Provides a pointer to the FusionExecutorCache that maps the current
+  //! unscheduled Fusion IRs to scheduled Fusion IRs for execution.
+  Nvf::FusionExecutorCache* fusionExecutorCachePtr() const;
+  //! Points to the nvFuser Fusion IR object
+  Nvf::Fusion* fusionPtr() const;
+
+  c10::optional<size_t> fusion_id_;
+};
+
+} // namespace nvfuser
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h
deleted file mode 100644
index dce8cc4..0000000
--- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h
+++ /dev/null
@@ -1,36 +0,0 @@
-
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-
-using namespace torch::jit::fuser::cuda;
-
-namespace nvfuser {
-
-class FusionOwner {
- public:
-  FusionOwner() : executor_cache_(std::make_unique<Fusion>()) {}
-
-  // Non-copyable
-  FusionOwner(const FusionOwner&) = delete;
-  FusionOwner& operator=(const FusionOwner&) = delete;
-
-  std::vector<at::Tensor> execute(const at::ArrayRef<c10::IValue>& inputs) {
-    return executor_cache_.runFusionWithInputs(inputs);
-  }
-  Fusion* fusionPtr() {
-    return executor_cache_.fusion();
-  }
-
-  void printIr() {
-    executor_cache_.printFusion();
-  }
-  void printKernel() {
-    executor_cache_.fusion()->printKernel();
-  }
-
- private:
-  FusionExecutorCache executor_cache_;
-};
-
-} // namespace nvfuser
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h
index 4616bd1..e5fb374 100644
--- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h
@@ -4,31 +4,166 @@
 #include <torch/csrc/jit/codegen/cuda/ops/alias.h>
 #include <torch/csrc/jit/codegen/cuda/ops/normalization.h>
 #include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
+#include <torch/csrc/jit/codegen/cuda/utils.h>
 
 namespace nvfuser {
 
+//! This enum it to give a Record Type for record hashing given that the
+//! record type is otherwise determined via the success of dynamic casting.
+//! This means that templated types are not specifically enumerated for
+//! each set of template arguments.
+enum class RecordType {
+  Base = 0,
+  Op,
+  BroadcastOp,
+  CastOp,
+  Constant,
+  End,
+  Tensor,
+  Output,
+  ReductionOp,
+  Scalar,
+  SqueezeOp,
+  Start,
+  VarianceOp,
+  VarianceMeanOp,
+};
+
 //! RecordFunctor is the base class record for operations recorded by
 //! the FusionDefinition.  It is, in essence, a node in the graph with
-//! input edges, args, and outputs edges outputs that where the stored
+//! input edges, args, and outputs edges outputs where the stored
 //! values are indices into the recorded state.
 //!
-//! The virual functor is the operators that is replayed on a cache
-//! to build the appropriate part of the nvFuser Fusion IR for a given
-//! record.
+//! The virual functor operator is executed on a cache miss to build the
+//! appropriate part of the nvFuser Fusion IR for a given record.
+//!
+//! The hash and equality operators are used to facilitate the hashing of
+//! RecordFunctors in a hash map given those operators need to be
+//! specified for custom objects.
+//!
+//! The print function is used to print the given Record as a statement
+//! in a python formated function.
 
 struct RecordFunctor {
-  RecordFunctor(std::vector<size_t> _args, std::vector<size_t> _outputs)
-      : args(std::move(_args)), outputs(std::move(_outputs)) {}
+  RecordFunctor(
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
+      RecordType _record_type)
+      : args_(std::move(_args)),
+        outputs_(std::move(_outputs)),
+        name_(std::move(_name)),
+        record_type_(_record_type) {}
   virtual ~RecordFunctor() = default;
+  //! Allows for copying of Child Class objects with RecordFunctor pointers.
+  virtual RecordFunctor* clone() = 0;
+
+  //! The base class is placing the type, outputs, and args hashed as follows:
+  //! | 63 - 56 | 55 - 48 | 47 ----------- 32 | 32 ------------------------  0 |
+  //! | Type    | Outputs | Args              | Child Class Specified          |
+  virtual size_t hash() const {
+    size_t arg_hash = 0;
+    for (auto arg : args_) {
+      arg_hash ^= ((arg.index << 1) ^ static_cast<size_t>(arg.stype));
+    }
+    size_t output_hash = 0;
+    for (auto output : outputs_) {
+      output_hash ^= ((output.index << 1) ^ static_cast<size_t>(output.stype));
+    }
+    return ((static_cast<size_t>(record_type_) & 0xff) << 56) |
+        ((output_hash & 0xff) << 48) | ((arg_hash & 0xffff) << 32);
+  }
+
+  //! The base virtual equality operator is defined so all child
+  //! classes can utilize the check for the same args and outputs.
+  virtual bool operator==(const RecordFunctor& other) const {
+    auto result = (record_type_ == other.record_type_);
+    result = result && (args_.size() == other.args_.size()) &&
+        (outputs_.size() == other.outputs_.size());
+    if (result) {
+      for (size_t i = 0; i < args_.size(); ++i) {
+        if ((args_[i].index != other.args_[i].index) ||
+            (args_[i].stype != other.args_[i].stype)) {
+          result = false;
+          break;
+        }
+      }
+    }
+    if (result) {
+      for (size_t i = 0; i < outputs_.size(); ++i) {
+        if ((outputs_[i].index != other.outputs_[i].index) ||
+            (outputs_[i].stype != other.outputs_[i].stype)) {
+          result = false;
+          break;
+        }
+      }
+    }
+    return result;
+  }
 
   //! Abstraction for an operation to build this record's nvFuser Fusion IR
   //! piece if the recording has a cache miss.
   virtual void operator()(FusionDefinition& fd) = 0;
 
+  //! The base print function when printing Record for a given FusionDefinition
+  //! in python formated code.
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    bool first_output = true;
+    for (auto& output : outputs_) {
+      if (first_output) {
+        first_output = false;
+      } else {
+        os << ", ";
+      }
+      if (output.stype == StateType::Scalar) {
+        os << "S";
+      } else if (output.stype == StateType::Tensor) {
+        os << "T";
+      } else {
+        TORCH_INTERNAL_ASSERT(false, "Unsupported StateType");
+      }
+      os << output.index;
+    }
+    if (outputs_.size() > 0) {
+      os << " = "
+         << "fd." << name_ << "(";
+    } else {
+      os << "fd." << name_ << "(";
+    }
+    bool first_arg = true;
+    for (auto& arg : args_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      if (arg.stype == StateType::Scalar) {
+        os << "S";
+      } else if (arg.stype == StateType::Tensor) {
+        os << "T";
+      } else {
+        TORCH_INTERNAL_ASSERT(false, "Unsupported StateType");
+      }
+      os << arg.index;
+    }
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+  RecordType recordType() const {
+    return record_type_;
+  }
+
+ protected:
   //! Inputs that are indices into the FusionDefinition's Recorded State.
-  std::vector<size_t> args;
+  std::vector<State> args_;
   //! Outputs that are indices into the FusionDefinition's Recorded State.
-  std::vector<size_t> outputs;
+  std::vector<State> outputs_;
+  //! Record Name
+  std::string name_;
+  //! Record Type of child class used for hashing
+  RecordType record_type_;
 };
 
 //! The OpRecord RecordFunctor is the most widely used child class because
@@ -43,12 +178,65 @@
 template <class OutType, class... ArgTypes>
 struct OpRecord : RecordFunctor {
   OpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
       std::function<OutType(ArgTypes...)> fusion_op)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            _name,
+            RecordType::Op),
         fusion_op_(fusion_op) {}
   virtual ~OpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new OpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------------------------------  0 |
+  //! | Arith Function Sigs hash code               |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (fusion_op_.target_type().hash_code() & 0xffffffff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    // A succesfull cast indicates a RecordFunctor of the same child class
+    if (auto child_ptr = dynamic_cast<const OpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        // Match the nvFuser arith function types
+        result = result &&
+            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout << "\nOpRecord: " << name_ << " Target Type [self: 0x"
+                    << fusion_op_.target_type().name() << "] [other: 0x"
+                    << child_ptr->fusion_op_.target_type().name() << "] ";
+        }
+        // Match the nvFuser arith function pointers
+        // IMPORTANT! you need to dereference the target pointer in order
+        // to match the function
+        result = result &&
+            (*fusion_op_.template target<OutType (*)(ArgTypes...)>() ==
+             *child_ptr->fusion_op_
+                  .template target<OutType (*)(ArgTypes...)>());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout
+              << "Target  Ptr [self: 0x" << std::hex
+              << (size_t)*fusion_op_.template target<OutType (*)(ArgTypes...)>()
+              << "] [other: 0x" << std::hex
+              << (size_t)*child_ptr->fusion_op_
+                     .template target<OutType (*)(ArgTypes...)>()
+              << "]\n";
+        }
+      }
+    }
+    return result;
+  }
 
   //! The variadic set of indices for the number of args for this op are
   //! deduced by providing the index_sequence as a parameter.  Similarly,
@@ -58,9 +246,9 @@
   //! to a Fusion IR TensorView or leave it as a Fusion IR Val (Scalar).
   //!
   //! A deduced binary op could look like:
-  //!   OutType opFunc<std::tuple<NvfTensor*, NvfTensor*>, 0, 1>
+  //!   OutType opFunc<std::tuple<TensorView*, TensorView*>, 0, 1>
   //! A deduced ternary op could look like:
-  //!   OutTupe opFunc<std::tuple<NvfTensor*, NvfVal*, NvfVal*>, 0, 1, 2>
+  //!   OutTupe opFunc<std::tuple<TensorView*, Val*, Val*>, 0, 1, 2>
   template <class TupleType, std::size_t... Is>
   OutType opFunc(
       FusionDefinition& fd,
@@ -68,17 +256,17 @@
       std::index_sequence<Is...>) {
     return fusion_op_(
         dynamic_cast<typename std::tuple_element<Is, TupleType>::type>(
-            fd.getFusionState(args.at(Is)))...);
+            fd.getFusionState(args_.at(Is).index))...);
   }
 
-  void operator()(FusionDefinition& fd) final {
+  virtual void operator()(FusionDefinition& fd) final {
     using arg_tuple_t = std::tuple<ArgTypes...>;
     auto indices =
         std::make_index_sequence<std::tuple_size<arg_tuple_t>::value>();
     // The tuple variable is never populated, it is passed for its type.
     arg_tuple_t inputs;
     auto output = opFunc(fd, inputs, indices);
-    fd.setFusionState(outputs.at(0), output);
+    fd.setFusionState(outputs_.at(0).index, output);
   }
 
  private:
@@ -88,21 +276,82 @@
 
 struct SqueezeOpRecord : RecordFunctor {
   SqueezeOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
+      std::vector<State> _args,
+      std::vector<State> _outputs,
       std::vector<int64_t>& original_shape,
       int64_t dim)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            "squeeze",
+            RecordType::SqueezeOp),
         original_shape_(std::move(original_shape)),
         dim_(dim) {}
   virtual ~SqueezeOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new SqueezeOpRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------- 16 | 15 --------------  0 |
+  //! | Squeeze Dim hash     | original_shape hash  |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t original_shape_hash = 0;
+    for (auto shape : original_shape_) {
+      original_shape_hash ^= static_cast<size_t>(shape);
+    }
+    size_t squeeze_dim_hash = static_cast<size_t>(dim_);
+    squeeze_dim_hash = (squeeze_dim_hash & 0xffff) << 16;
+    return result | squeeze_dim_hash | (original_shape_hash & 0xffff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const SqueezeOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = (original_shape_.size() == child_ptr->original_shape_.size());
+        if (result) {
+          result = (dim_ == child_ptr->dim_);
+        }
+        if (result) {
+          for (size_t i = 0; i < original_shape_.size(); ++i) {
+            if (original_shape_[i] != child_ptr->original_shape_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
 
   void operator()(FusionDefinition& fd) final {
-    auto arg = fd.getFusionState(args.at(0))->template as<TensorView>();
+    auto arg =
+        fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
+    auto output = Nvf::squeeze(arg, original_shape_, dim_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
 
-    auto output = torch::jit::fuser::cuda::squeeze(arg, original_shape_, dim_);
-
-    fd.setFusionState(outputs.at(0), output);
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << ", original_shape=[";
+    bool first_arg = true;
+    for (auto shape : original_shape_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << shape;
+    }
+    os << "]";
+    os << ", dim=" << dim_;
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
@@ -116,17 +365,72 @@
 
 struct BroadcastOpRecord : RecordFunctor {
   BroadcastOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
       std::vector<int64_t>& output_shape,
       std::vector<int64_t>& broadcast_dims)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            _name,
+            RecordType::BroadcastOp),
         output_shape_(std::move(output_shape)),
         broadcast_dims_(std::move(broadcast_dims)) {}
   virtual ~BroadcastOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new BroadcastOpRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    auto arg = fd.getFusionState(args.at(0))->template as<TensorView>();
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -------------- 16 | 15 --------------  0 |
+  //! | broadcast_dims hash  | output_shape hash    |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t output_shape_hash = 0;
+    for (auto shape : output_shape_) {
+      output_shape_hash ^= static_cast<size_t>(shape);
+    }
+    size_t broadcast_dims_hash = 0;
+    for (auto dim : broadcast_dims_) {
+      broadcast_dims_hash |= 1 << ((output_shape_.size() - 1) - dim);
+    }
+    broadcast_dims_hash = (broadcast_dims_hash & 0xffff) << 16;
+    return result | broadcast_dims_hash | (output_shape_hash & 0xffff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const BroadcastOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result =
+            ((output_shape_.size() == child_ptr->output_shape_.size()) &&
+             (broadcast_dims_.size() == child_ptr->broadcast_dims_.size()));
+        if (result) {
+          for (size_t i = 0; i < output_shape_.size(); ++i) {
+            if (output_shape_[i] != child_ptr->output_shape_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+        if (result) {
+          for (size_t i = 0; i < broadcast_dims_.size(); ++i) {
+            if (broadcast_dims_[i] != child_ptr->broadcast_dims_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    auto arg =
+        fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
 
     const auto& arg_domains_nr = arg->domain()->noReductions();
     const auto arg_ndims = arg_domains_nr.size();
@@ -168,18 +472,48 @@
           output_shape_[idx] != -1) {
         // TODO: this would be tricky to handle on dynamic shapes, we'll
         // need to pass-in a symbol instead somehow.
-        output_shape_on_bcast[idx] = IrBuilder::create<Int>(output_shape_[idx]);
+        output_shape_on_bcast[idx] =
+            Nvf::IrBuilder::create<Nvf::Int>(output_shape_[idx]);
         has_expand = true;
       } else {
-        output_shape_on_bcast[idx] = IrBuilder::create<Int>(-1);
+        output_shape_on_bcast[idx] = Nvf::IrBuilder::create<Nvf::Int>(-1);
       }
     }
 
-    auto output = torch::jit::fuser::cuda::broadcast(arg, is_broadcast_dim);
+    auto output = Nvf::broadcast(arg, is_broadcast_dim);
     if (has_expand) {
-      output = torch::jit::fuser::cuda::expand(output, output_shape_on_bcast);
+      output = Nvf::expand(output, output_shape_on_bcast);
     }
-    fd.setFusionState(outputs.at(0), output);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << ", output_shape=[";
+    bool first_arg = true;
+    for (auto shape : output_shape_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << shape;
+    }
+    os << "]";
+    os << ", broadcast_dims=[";
+    first_arg = true;
+    for (auto dim : broadcast_dims_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << dim;
+    }
+    os << "]";
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
@@ -194,39 +528,140 @@
 template <class OutType, class ArgType>
 struct CastOpRecord : RecordFunctor {
   CastOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
-      std::function<OutType(NvfDataType, ArgType)> fusion_op,
-      NvfDataType dtype)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
+      std::function<OutType(Nvf::DataType, ArgType)> fusion_op,
+      Nvf::DataType dtype)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            _name,
+            RecordType::CastOp),
         fusion_op_(fusion_op),
         dtype_(dtype) {}
   virtual ~CastOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new CastOpRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    auto arg = dynamic_cast<ArgType>(fd.getFusionState(args.at(0)));
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 --- 24 | 23 --------------------------  0 |
+  //! | Dtype     | Arith Function Sig hash code     |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    result |= ((static_cast<size_t>(dtype_) & 0xff) << 24);
+    result |= (fusion_op_.target_type().hash_code() & 0xffffff);
+    return result;
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const CastOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = result &&
+            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout << "\nCastOpRecord: " << name_ << " Target Type [self: 0x"
+                    << fusion_op_.target_type().name() << "] [other: 0x"
+                    << child_ptr->fusion_op_.target_type().name() << "]";
+        }
+        // IMPORTANT! you need to dereference the target pointer in order
+        // to match the function
+        result = result &&
+            (*fusion_op_
+                  .template target<OutType (*)(Nvf::DataType, ArgType)>() ==
+             *child_ptr->fusion_op_
+                  .template target<OutType (*)(Nvf::DataType, ArgType)>());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout
+              << " Target  Ptr [self: 0x" << std::hex
+              << (size_t)*fusion_op_
+                     .template target<OutType (*)(Nvf::DataType, ArgType)>()
+              << "] [other: 0x" << std::hex
+              << (size_t)*child_ptr->fusion_op_
+                     .template target<OutType (*)(Nvf::DataType, ArgType)>()
+              << "]\n";
+        }
+        result = result && (dtype_ == child_ptr->dtype_);
+      }
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    auto arg = dynamic_cast<ArgType>(fd.getFusionState(args_.at(0).index));
     auto output = fusion_op_(dtype_, arg);
-    fd.setFusionState(outputs.at(0), output);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
   //! nvFuser arith function signature
-  std::function<OutType(NvfDataType, ArgType)> fusion_op_;
+  std::function<OutType(Nvf::DataType, ArgType)> fusion_op_;
   //! Type to cast to.
-  NvfDataType dtype_;
+  Nvf::DataType dtype_;
 };
 
 //! Specialized Record Functor for recording FusionDefinition constant state.
 
 template <typename ExprType, typename ValueType>
 struct ConstantRecord : RecordFunctor {
-  ConstantRecord(std::vector<size_t> _outputs, ValueType val)
-      : RecordFunctor({}, std::move(_outputs)), value_(val) {}
+  ConstantRecord(std::vector<State> _outputs, ValueType val)
+      : RecordFunctor(
+            {},
+            std::move(_outputs),
+            "define_constant",
+            RecordType::Constant),
+        value_(val) {}
   virtual ~ConstantRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new ConstantRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    NvfVal* output = IrBuilder::create<ExprType>(value_);
-    fd.setFusionState(outputs.at(0), output);
+  //! Going to start out hashing nothing extra since hashing a complex number
+  //! seems complicated.  Initially, the thought was to simply static cast the
+  //! value_
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result;
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const ConstantRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (value_ == child_ptr->value_);
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    Nvf::Val* output = Nvf::IrBuilder::create<ExprType>(value_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    if (std::is_same<ValueType, bool>::value) {
+      os << (value_ ? "True" : "False");
+    } else {
+      os << std::showpoint << value_;
+    }
+
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
@@ -234,67 +669,209 @@
   ValueType value_;
 };
 
+//! Specialized Record Functor for recording FusionDefinition End.
+//! The accompanying Fusion Cache Entry holds a Fusion Object.
+
+struct EndRecord : RecordFunctor {
+  EndRecord() : RecordFunctor({}, {}, "end", RecordType::End) {}
+  virtual ~EndRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new EndRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 ---------------------------------------  0 |
+  //! | None                                          |
+  virtual size_t hash() const final {
+    return RecordFunctor::hash();
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (dynamic_cast<const EndRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {}
+};
+
 //! Specialized Record Functor for recording FusionDefinition input tensors.
 
-struct InputTensorRecord : RecordFunctor {
-  InputTensorRecord(
-      std::vector<size_t> _outputs,
+struct TensorRecord : RecordFunctor {
+  TensorRecord(
+      std::vector<State> _outputs,
       std::vector<int64_t> _symbolic_sizes,
       std::vector<bool> _contiguous_info,
-      NvfDataType _dtype,
+      Nvf::DataType _dtype,
       bool _is_cpu = false)
-      : RecordFunctor({}, std::move(_outputs)),
-        symbolic_sizes(std::move(_symbolic_sizes)),
-        contiguous_info(std::move(_contiguous_info)),
-        dtype(_dtype),
-        is_cpu(_is_cpu) {}
-  virtual ~InputTensorRecord() = default;
+      : RecordFunctor(
+            {},
+            std::move(_outputs),
+            "define_tensor",
+            RecordType::Tensor),
+        symbolic_sizes_(std::move(_symbolic_sizes)),
+        contiguous_info_(std::move(_contiguous_info)),
+        dtype_(_dtype),
+        is_cpu_(_is_cpu) {}
+  virtual ~TensorRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new TensorRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    auto tv = TensorViewBuilder()
-                  .ndims(symbolic_sizes.size())
-                  .contiguity(contiguous_info)
-                  .shape(symbolic_sizes)
-                  .dtype(dtype)
-                  .build();
-
-    if (symbolic_sizes.empty() && is_cpu) {
-      tv->setCpuScalar(true);
-    } else {
-      TORCH_CHECK(!is_cpu, "cpu non-scalar tensor is not supported");
+  //! Child specific hash function in lower 32 bits.
+  //! |  31  | 30 --- 24 | 23 --------- 12 | 11 ---------  0 |
+  //! | CPU? | Dtype     | Symbolic Sizes  | Contiguous Info |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t ssize_hash = 0;
+    for (size_t i = 0; i < symbolic_sizes_.size(); ++i) {
+      size_t ssize = 0;
+      if (symbolic_sizes_[i] == -1) {
+        ssize = 1;
+      }
+      ssize_hash |= (ssize << (symbolic_sizes_.size() - 1 - i));
+    }
+    size_t contig_hash = 0;
+    for (size_t i = 0; i < contiguous_info_.size(); ++i) {
+      contig_hash |= (contiguous_info_[i] << (contiguous_info_.size() - 1 - i));
     }
 
-    fd.setFusionState(outputs.at(0), tv);
+    result |= ((static_cast<size_t>(is_cpu_) & 0x1) << 31);
+    result |= ((static_cast<size_t>(dtype_) & 0x7f) << 24);
+    return result | ((ssize_hash & 0xfff) << 12) | (contig_hash & 0xfff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const TensorRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dtype_ == child_ptr->dtype_);
+      result = result && (is_cpu_ == child_ptr->is_cpu_);
+      if (result) {
+        result =
+            ((symbolic_sizes_.size() == child_ptr->symbolic_sizes_.size()) &&
+             (contiguous_info_.size() == child_ptr->contiguous_info_.size()));
+        if (result) {
+          for (size_t i = 0; i < symbolic_sizes_.size(); ++i) {
+            if (symbolic_sizes_[i] != child_ptr->symbolic_sizes_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+        if (result) {
+          for (size_t i = 0; i < contiguous_info_.size(); ++i) {
+            if (contiguous_info_[i] != child_ptr->contiguous_info_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    auto tv = Nvf::TensorViewBuilder()
+                  .ndims(symbolic_sizes_.size())
+                  .contiguity(contiguous_info_)
+                  .shape(symbolic_sizes_)
+                  .dtype(dtype_)
+                  .build();
+
+    if (symbolic_sizes_.empty() && is_cpu_) {
+      tv->setCpuScalar(true);
+    } else {
+      TORCH_CHECK(!is_cpu_, "CPU non-scalar tensor is not supported!");
+    }
+
+    fd.setFusionState(outputs_.at(0).index, tv);
     fd.addInput(tv);
   }
 
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << "symbolic_sizes=[";
+    bool first_arg = true;
+    for (auto ss : symbolic_sizes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << ss;
+    }
+    os << "], contiguous=[";
+    first_arg = true;
+    for (auto ci : contiguous_info_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      if (ci) {
+        os << "True";
+      } else {
+        os << "False";
+      }
+    }
+    os << "], dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+ private:
   //! A vector of tensor dimension sizes.
   //! This vector only captures sizes of -1 or 1 to indicate a symbolic
   //! dimension (-1) or a broadcast dimension (1).
-  std::vector<int64_t> symbolic_sizes;
+  std::vector<int64_t> symbolic_sizes_;
   //! A vector to indicate whether the a tensor dimension is contiguous
   //! with the dimension just to its right.
-  std::vector<bool> contiguous_info;
+  std::vector<bool> contiguous_info_;
   //! Tensor data type.
-  NvfDataType dtype;
-  //! Tensor data type.
-  bool is_cpu;
+  Nvf::DataType dtype_;
+  //! Notes a scalar CPU Tensor
+  bool is_cpu_;
 };
 
 //! Specialized Record Functor for recording FusionDefinition outputs.
 
 template <class OutputType>
 struct OutputRecord : RecordFunctor {
-  OutputRecord(std::vector<size_t> _args)
-      : RecordFunctor(std::move(_args), {}) {}
+  OutputRecord(std::vector<State> _args)
+      : RecordFunctor(std::move(_args), {}, "add_output", RecordType::Output) {}
   virtual ~OutputRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new OutputRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    auto input = fd.getFusionState(args.at(0));
+  //! Nothing extra necessary in hash
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 ---------------------------------------  0 |
+  //! | None                                          |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result;
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const OutputRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    auto input = fd.getFusionState(args_.at(0).index);
 
     // With C++17, this statement should be "if constexpr"
-    if (std::is_same<OutputType, NvfTensorView>::value) {
-      fd.addOutput(input->template as<NvfTensorView>());
+    if (std::is_same<OutputType, Nvf::TensorView>::value) {
+      fd.addOutput(input->template as<Nvf::TensorView>());
     } else {
       fd.addOutput(input);
     }
@@ -305,92 +882,315 @@
 
 struct ReductionOpRecord : RecordFunctor {
   ReductionOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
-      std::function<
-          NvfTensorView*(NvfTensorView*, std::vector<int>&, bool, NvfDataType)>
-          fusion_op,
+      std::vector<State> _args,
+      std::vector<State> _outputs,
+      std::string _name,
+      std::function<Nvf::TensorView*(
+          Nvf::TensorView*,
+          const std::vector<int>&,
+          bool,
+          Nvf::DataType)> fusion_op,
       std::vector<int> axes,
       bool keep_dim,
-      NvfDataType dtype)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      Nvf::DataType dtype)
+      : RecordFunctor(
+            std::move(_args),
+            std::move(_outputs),
+            _name,
+            RecordType::ReductionOp),
         fusion_op_(fusion_op),
         axes_(std::move(axes)),
         keep_dim_(keep_dim),
         dtype_(dtype) {}
   virtual ~ReductionOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new ReductionOpRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    auto arg = fd.getFusionState(args.at(0))->template as<NvfTensorView>();
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -- 28 | 27 --- 20 | 19 -----------------  0 |
+  //! | keep_dim | Dtype     | Axes Hash               |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t axes_hash = 0;
+    // Normally I would make a little endian hash of the axes but I do not
+    // know the size of the tensor based on just the record information.
+    for (size_t i = 0; i < axes_.size(); ++i) {
+      axes_hash |= (1 << axes_[i]);
+    }
+
+    return result | (static_cast<size_t>(keep_dim_) << 28) |
+        ((static_cast<size_t>(dtype_) & 0xff) << 20) | (axes_hash & 0xfffff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const ReductionOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      if (result) {
+        result = result &&
+            (fusion_op_.target_type() == child_ptr->fusion_op_.target_type());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout << "\nReductionOpRecord: " << name_
+                    << " Target Type [self: 0x"
+                    << fusion_op_.target_type().name() << "] [other: 0x"
+                    << child_ptr->fusion_op_.target_type().name() << "]";
+        }
+        // IMPORTANT! you need to dereference the target pointer in order
+        // to match the function
+        result = result &&
+            (*fusion_op_.template target<
+                 Nvf::
+                     TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>() ==
+             *child_ptr->fusion_op_.template target<
+                 Nvf::
+                     TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>());
+        if (Nvf::isDebugDumpEnabled(
+                Nvf::DebugDumpOption::PythonFrontendDebug)) {
+          std::cout
+              << " Target  Ptr [self: 0x" << std::hex
+              << (size_t)*fusion_op_.template target<
+                     Nvf::
+                         TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>()
+              << "] [other: 0x" << std::hex
+              << (size_t)*child_ptr->fusion_op_.template target<
+                     Nvf::
+                         TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>()
+              << "]\n";
+        }
+        result = result && (keep_dim_ == child_ptr->keep_dim_);
+        result = result && (dtype_ == child_ptr->dtype_);
+        if (result) {
+          result = (axes_.size() == child_ptr->axes_.size());
+          if (result) {
+            for (size_t i = 0; i < axes_.size(); ++i) {
+              if (axes_[i] != child_ptr->axes_[i]) {
+                result = false;
+                break;
+              }
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    auto arg =
+        fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
     auto output = fusion_op_(arg, axes_, keep_dim_, dtype_);
-    fd.setFusionState(outputs.at(0), output);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << ", axes=[";
+    bool first_arg = true;
+    for (auto axis : axes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << axis;
+    }
+    os << "]";
+    os << ", keepdim=" << (keep_dim_ ? "True" : "False");
+    os << ", dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
   //! nvFuser arith function signature for a given reduction operation
-  std::function<
-      NvfTensorView*(NvfTensorView*, std::vector<int>&, bool, NvfDataType)>
+  std::function<Nvf::TensorView*(
+      Nvf::TensorView*,
+      const std::vector<int>&,
+      bool,
+      Nvf::DataType)>
       fusion_op_;
   //! The tensor dimensions to reduce
   std::vector<int> axes_;
   //! Indicates whether to keep the reduced dimension(s).
   bool keep_dim_;
   //! The output data type.
-  NvfDataType dtype_;
+  Nvf::DataType dtype_;
 };
 
 //! Specialized Record Functor for recording FusionDefinition input scalars.
 
 struct ScalarRecord : RecordFunctor {
-  ScalarRecord(std::vector<size_t> _outputs, NvfDataType dtype)
-      : RecordFunctor({}, std::move(_outputs)), dtype_(dtype) {}
+  ScalarRecord(std::vector<State> _outputs, Nvf::DataType dtype)
+      : RecordFunctor(
+            {},
+            std::move(_outputs),
+            "define_scalar",
+            RecordType::Scalar),
+        dtype_(dtype) {}
   virtual ~ScalarRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new ScalarRecord(*this);
+  }
 
-  void operator()(FusionDefinition& fd) final {
-    NvfVal* output = nullptr;
-    if (dtype_ == NvfDataType::Double) {
-      output = IrBuilder::create<torch::jit::fuser::cuda::Double>();
-    } else if (dtype_ == NvfDataType::ComplexDouble) {
-      output = IrBuilder::create<torch::jit::fuser::cuda::ComplexDouble>();
-    } else if (dtype_ == NvfDataType::Bool) {
-      output = IrBuilder::create<torch::jit::fuser::cuda::Bool>();
-    } else if (dtype_ == NvfDataType::Int) {
-      output = IrBuilder::create<torch::jit::fuser::cuda::Int>();
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 ---------------------------------------  0 |
+  //! | Dtype                                         |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    return result | (static_cast<size_t>(dtype_) & 0xffffffff);
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const ScalarRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (dtype_ == child_ptr->dtype_);
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {
+    Nvf::Val* output = nullptr;
+    if (dtype_ == Nvf::DataType::Double) {
+      output = Nvf::IrBuilder::create<Nvf::Double>();
+    } else if (dtype_ == Nvf::DataType::ComplexDouble) {
+      output = Nvf::IrBuilder::create<Nvf::ComplexDouble>();
+    } else if (dtype_ == Nvf::DataType::Bool) {
+      output = Nvf::IrBuilder::create<Nvf::Bool>();
+    } else if (dtype_ == Nvf::DataType::Int) {
+      output = Nvf::IrBuilder::create<Nvf::Int>();
     } else {
       TORCH_CHECK(false, "Dtype is not supported:", dtype_);
     }
     fd.addInput(output);
-    fd.setFusionState(outputs.at(0), output);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+
+  virtual void print(std::ostream& os, bool close_function = true) const {
+    RecordFunctor::print(os, false);
+    os << "dtype=" << dtypeToPyString(dtype_);
+    if (close_function) {
+      os << ")";
+    }
   }
 
  private:
   //! Scalar data type.
-  NvfDataType dtype_;
+  Nvf::DataType dtype_;
 };
 
-//! Specialized Record Functor for the FusionDefinition's var op.
+//! Specialized Record Functor for recording FusionDefinition Start.
+//! There should only ever be one instance of this Record in the
+//! Fusion Cache.
 
-struct VarianceOpRecord : RecordFunctor {
-  VarianceOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
+struct StartRecord : RecordFunctor {
+  StartRecord() : RecordFunctor({}, {}, "start", RecordType::Start) {}
+  virtual ~StartRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new StartRecord(*this);
+  }
+
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 ---------------------------------------  0 |
+  //! | None                                          |
+  virtual size_t hash() const final {
+    return RecordFunctor::hash();
+  }
+
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (dynamic_cast<const StartRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+    }
+    return result;
+  }
+
+  virtual void operator()(FusionDefinition& fd) final {}
+};
+
+//! Specialized Record Functors for Normalization based ops.
+
+struct NormOpRecord : RecordFunctor {
+  NormOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::string name,
+      RecordType type,
       std::vector<int>& axes,
       int64_t correction,
       bool keep_dim)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
+      : RecordFunctor(std::move(args), std::move(outputs), name, type),
         axes_(axes),
         correction_(correction),
         keep_dim_(keep_dim) {}
-  virtual ~VarianceOpRecord() = default;
+  virtual ~NormOpRecord() = default;
+  virtual RecordFunctor* clone() = 0;
 
-  void operator()(FusionDefinition& fd) final {
-    auto arg = fd.getFusionState(args.at(0))->as<NvfTensorView>();
-    auto output =
-        torch::jit::fuser::cuda::variance(arg, axes_, correction_, keep_dim_);
-    fd.setFusionState(outputs.at(0), output);
+  // I am skipping the bassel's correction value in the hash because
+  // I suspect we might change it to a bool from a 64-bit value
+  //! Child specific hash function in lower 32 bits.
+  //! | 31 -- 28 | 27 -----------------------------  0 |
+  //! | keep_dim | Axes Hash                           |
+  virtual size_t hash() const final {
+    auto result = RecordFunctor::hash();
+    size_t axes_hash = 0;
+    // Normally I would make a little endian hash of the axes but I do not
+    // know the size of the tensor based on just the record information.
+    for (size_t i = 0; i < axes_.size(); ++i) {
+      axes_hash |= (1 << axes_[i]);
+    }
+    return result | (static_cast<size_t>(keep_dim_) << 28) |
+        (axes_hash & 0xfffffff);
   }
 
- private:
+  virtual bool operator==(const RecordFunctor& other) const final {
+    auto result = false;
+    if (auto child_ptr = dynamic_cast<const NormOpRecord*>(&other)) {
+      result = RecordFunctor::operator==(other);
+      result = result && (correction_ == child_ptr->correction_);
+      result = result && (keep_dim_ == child_ptr->keep_dim_);
+      if (result) {
+        result = (axes_.size() == child_ptr->axes_.size());
+        if (result) {
+          for (size_t i = 0; i < axes_.size(); ++i) {
+            if (axes_[i] != child_ptr->axes_[i]) {
+              result = false;
+              break;
+            }
+          }
+        }
+      }
+    }
+    return result;
+  }
+
+  //! Each NormOp Child should define the operator() to build the IR
+  virtual void operator()(FusionDefinition& fd) = 0;
+
+  virtual void print(std::ostream& os, bool close_function = true) const final {
+    RecordFunctor::print(os, false);
+    os << ", axes=[";
+    bool first_arg = true;
+    for (auto axis : axes_) {
+      if (first_arg) {
+        first_arg = false;
+      } else {
+        os << ", ";
+      }
+      os << axis;
+    }
+    os << "]";
+    os << ", correction=" << correction_;
+    os << ", keepdim=" << (keep_dim_ ? "True" : "False");
+    if (close_function) {
+      os << ")";
+    }
+  }
+
+ protected:
   //! Dimensions of tensor to reduce for variance calculation
   std::vector<int> axes_;
   //! Bessel's correction value
@@ -399,34 +1199,88 @@
   bool keep_dim_;
 };
 
-struct VarianceMeanOpRecord : RecordFunctor {
-  VarianceMeanOpRecord(
-      std::vector<size_t> _args,
-      std::vector<size_t> _outputs,
-      std::vector<int>& dims,
+struct VarianceOpRecord : NormOpRecord {
+  VarianceOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::vector<int>& axes,
       int64_t correction,
-      bool keepdim)
-      : RecordFunctor(std::move(_args), std::move(_outputs)),
-        dims_(dims),
-        correction_(correction),
-        keepdim_(keepdim) {}
-  virtual ~VarianceMeanOpRecord() = default;
-
-  void operator()(FusionDefinition& fd) final {
-    auto arg = fd.getFusionState(args.at(0))->as<NvfTensorView>();
-    auto output = torch::jit::fuser::cuda::variance_mean(
-        arg, dims_, correction_, keepdim_);
-    fd.setFusionState(outputs.at(0), output.var);
-    fd.setFusionState(outputs.at(1), output.mean);
+      bool keep_dim)
+      : NormOpRecord(
+            std::move(args),
+            std::move(outputs),
+            "ops.var",
+            RecordType::VarianceOp,
+            axes,
+            correction,
+            keep_dim) {}
+  virtual ~VarianceOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new VarianceOpRecord(*this);
   }
 
- private:
-  //! Dimensions of tensor to reduce for variance calculation
-  std::vector<int> dims_;
-  //! Bessel's correction value
-  int64_t correction_;
-  //! Indicates whether to keep the reduced dimension(s).
-  bool keepdim_;
+  virtual void operator()(FusionDefinition& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<Nvf::TensorView>();
+    auto output = Nvf::variance(arg, axes_, correction_, keep_dim_);
+    fd.setFusionState(outputs_.at(0).index, output);
+  }
+};
+
+//! VarianceMean requires a separate Record because nvFuser defines the output
+//! of var_mean as a custom struct.
+struct VarianceMeanOpRecord : NormOpRecord {
+  VarianceMeanOpRecord(
+      std::vector<State> args,
+      std::vector<State> outputs,
+      std::vector<int>& axes,
+      int64_t correction,
+      bool keep_dim)
+      : NormOpRecord(
+            std::move(args),
+            std::move(outputs),
+            "ops.var_mean",
+            RecordType::VarianceMeanOp,
+            axes,
+            correction,
+            keep_dim) {}
+  virtual ~VarianceMeanOpRecord() = default;
+  virtual RecordFunctor* clone() final {
+    return new VarianceMeanOpRecord(*this);
+  }
+
+  void operator()(FusionDefinition& fd) final {
+    auto arg = fd.getFusionState(args_.at(0).index)->as<Nvf::TensorView>();
+    auto output = Nvf::variance_mean(arg, axes_, correction_, keep_dim_);
+    fd.setFusionState(outputs_.at(0).index, output.var);
+    fd.setFusionState(outputs_.at(1).index, output.mean);
+  }
 };
 
 } // namespace nvfuser
+
+//! Creating the template specialized hash and equal_to functions for a
+//! RecordFunctor object in order to use hash maps (unordered_maps) in STL.
+namespace std {
+using namespace nvfuser;
+
+template <>
+struct hash<RecordFunctor*> {
+  size_t operator()(const RecordFunctor* p) const {
+    TORCH_CHECK(p, "The RecordFunctor Pointer for hashing is null!");
+    return p->hash();
+  }
+};
+template <>
+struct equal_to<RecordFunctor*>
+    : public binary_function<RecordFunctor*, RecordFunctor*, bool> {
+  bool operator()(const RecordFunctor* p, const RecordFunctor* q) const {
+    TORCH_CHECK(
+        p,
+        "The RecordFunctor Pointer on the lhs of an equality check is null!");
+    TORCH_CHECK(
+        q,
+        "The RecordFunctor Pointer on the rhs of an equality check is null!");
+    return p->operator==(*q);
+  }
+};
+} // namespace std
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp
index 2d21156..e09a24c 100644
--- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp
@@ -4,10 +4,13 @@
 #include <c10/util/ArrayRef.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ops/composite.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
 #include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h>
 #include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
 #include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
@@ -24,27 +27,41 @@
   auto nvfuser = m.def_submodule("_nvfuser");
 
   //! DataTypes supported by nvFuser in the FusionDefinition
-  py::enum_<NvfDataType>(nvfuser, "DataType")
-      .value("Double", NvfDataType::Double)
-      .value("Float", NvfDataType::Float)
-      .value("Half", NvfDataType::Half)
-      .value("Int", NvfDataType::Int)
-      .value("Int32", NvfDataType::Int32)
-      .value("Bool", NvfDataType::Bool)
-      .value("BFloat16", NvfDataType::BFloat16)
-      .value("ComplexFloat", NvfDataType::ComplexFloat)
-      .value("ComplexDouble", NvfDataType::ComplexDouble)
-      .value("Null", NvfDataType::Null);
+  py::enum_<Nvf::DataType>(nvfuser, "DataType")
+      .value("Double", Nvf::DataType::Double)
+      .value("Float", Nvf::DataType::Float)
+      .value("Half", Nvf::DataType::Half)
+      .value("Int", Nvf::DataType::Int)
+      .value("Int32", Nvf::DataType::Int32)
+      .value("Bool", Nvf::DataType::Bool)
+      .value("BFloat16", Nvf::DataType::BFloat16)
+      .value("ComplexFloat", Nvf::DataType::ComplexFloat)
+      .value("ComplexDouble", Nvf::DataType::ComplexDouble)
+      .value("Null", Nvf::DataType::Null);
 
-  //! Binding an object that owns a FusionExecutorCache instance and provides
-  //! an interface
-  //! \todo This object will be removed when a FusionManager is added
-  //! containing a cache.
-  py::class_<nvfuser::FusionOwner> fusion(nvfuser, "Fusion");
+  //! Binding the FusionCache that holds a cache of Fusions
+  //! This is only bound to provide an interface to get the number of fusions
+  //! that are cached.
+  py::class_<nvfuser::FusionCache> fusion_cache(nvfuser, "FusionCache");
+  fusion_cache
+      .def_static(
+          "get",
+          &nvfuser::FusionCache::get,
+          py::arg("max_fusions") = int(8192),
+          py::return_value_policy::reference)
+      .def("num_fusions", &nvfuser::FusionCache::numFusions)
+      .def("print_stats", [](nvfuser::FusionCache& self) {
+        self.print(std::cout);
+      });
+
+  py::class_<nvfuser::FusionInterface> fusion(nvfuser, "Fusion");
   fusion.def(py::init<>())
+      .def(py::init<size_t>(), py::arg("fusion_id"))
+      .def("define", &nvfuser::FusionInterface::define)
+      .def("defined", &nvfuser::FusionInterface::defined)
       .def(
           "execute",
-          [](nvfuser::FusionOwner& self, const py::iterable& iter) {
+          [](nvfuser::FusionInterface& self, const py::iterable& iter) {
             std::vector<IValue> inputs;
             for (py::handle obj : iter) {
               inputs.push_back(toIValue(obj, c10::AnyType::get()));
@@ -52,10 +69,8 @@
             return self.execute(inputs);
           },
           py::return_value_policy::reference)
-      .def("print_ir", [](nvfuser::FusionOwner& self) { self.printIr(); })
-      .def("print_kernel", [](nvfuser::FusionOwner& self) {
-        self.printKernel();
-      });
+      .def("id", &nvfuser::FusionInterface::id)
+      .def("print", &nvfuser::FusionInterface::print);
 
   //! These are the FusionDefinition supported object types that are either
   //! defined as inputs or the output of an operation.
@@ -66,11 +81,18 @@
   //! define the set the operations and connections between operations for
   //! nvFuser to create.
   py::class_<nvfuser::FusionDefinition> fusion_def(nvfuser, "FusionDefinition");
-  fusion_def.def(py::init<nvfuser::FusionOwner*>())
+  fusion_def
+      .def(
+          py::init<nvfuser::FusionInterface*, int>(),
+          py::arg("fusion"),
+          py::arg("max_length") = int(256))
       .def_readwrite("ops", &nvfuser::FusionDefinition::ops)
       .def(
           "__enter__",
           [](nvfuser::FusionDefinition& self) -> nvfuser::FusionDefinition* {
+            // Instrumentation to mark the beginning of a FusionDefinition
+            Nvf::inst::Trace::instance()->beginEvent(
+                "FusionDefinition Context Manager");
             return self.enter();
           })
       .def(
@@ -78,47 +100,99 @@
           [](nvfuser::FusionDefinition& self,
              void* exc_type,
              void* exc_value,
-             void* traceback) { self.exit(); })
+             void* traceback) {
+            self.exit();
+            // Mark the end of a FusionDefinition Context Manager
+            Nvf::inst::Trace::instance()->endEvent(nullptr);
+          })
       .def(
-          "add_output",
-          [](nvfuser::FusionDefinition& self, nvfuser::Scalar* output) {
-            self.defineRecord(
-                new nvfuser::OutputRecord<NvfVal>({output->index}));
+          "__str__",
+          [](nvfuser::FusionDefinition& self) {
+            std::stringstream ss;
+            self.print(ss);
+            return ss.str();
           })
       .def(
           "add_output",
-          [](nvfuser::FusionDefinition& self, nvfuser::Tensor* output) {
-            self.defineRecord(
-                new nvfuser::OutputRecord<NvfTensorView>({output->index}));
+          [](nvfuser::FusionDefinition& self, nvfuser::Scalar output) {
+            FUSER_PERF_SCOPE("FusionDefinition.add_output (scalar)");
+            self.defineRecord(new nvfuser::OutputRecord<Nvf::Val>(
+                {self.recordingState(output())}));
+          })
+      .def(
+          "add_output",
+          [](nvfuser::FusionDefinition& self, nvfuser::Tensor output) {
+            FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)");
+            self.defineRecord(new nvfuser::OutputRecord<Nvf::TensorView>(
+                {self.recordingState(output())}));
           })
       .def(
           "define_tensor",
           [](nvfuser::FusionDefinition& self,
              size_t ndims,
-             NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* {
+             Nvf::DataType dtype = Nvf::DataType::Float,
+             bool is_cpu = false) -> nvfuser::Tensor {
+            FUSER_PERF_SCOPE("FusionDefinition.define_tensor (simple)");
             std::vector<int64_t> maybe_symbolic_sizes(ndims, -1);
             ;
             std::vector<bool> contig_info(ndims, false);
 
-            nvfuser::Tensor* out = self.defineTensor();
-            self.defineRecord(new nvfuser::InputTensorRecord(
-                {out->index},
+            nvfuser::Tensor out = self.defineTensor();
+            self.defineRecord(new nvfuser::TensorRecord(
+                {self.recordingState(out())},
                 std::move(maybe_symbolic_sizes),
                 std::move(contig_info),
-                dtype));
+                dtype,
+                is_cpu));
 
             return out;
           },
           py::arg("ndims"),
-          py::arg("dtype") = torch::jit::fuser::cuda::DataType::Float,
+          py::arg("dtype") = Nvf::DataType::Float,
+          py::arg("is_cpu") = false,
           py::return_value_policy::reference)
       .def(
           "define_tensor",
           [](nvfuser::FusionDefinition& self,
-             std::vector<int64_t> sizes,
-             std::vector<int64_t> strides,
-             NvfDataType dtype = NvfDataType::Float,
-             bool is_cpu = false) -> nvfuser::Tensor* {
+             std::vector<int64_t>& symbolic_sizes,
+             std::vector<bool>& contiguous,
+             Nvf::DataType dtype = Nvf::DataType::Float,
+             bool is_cpu = false) -> nvfuser::Tensor {
+            FUSER_PERF_SCOPE("FusionDefinition.define_tensor (default)");
+
+            for (size_t i = 0; i < symbolic_sizes.size(); ++i) {
+              TORCH_CHECK(
+                  symbolic_sizes[i] == -1 || symbolic_sizes[i] == 1,
+                  "The value ",
+                  symbolic_sizes[i],
+                  " at index ",
+                  i,
+                  " was neither broadcast(1) or symbolic(-1).");
+            }
+
+            nvfuser::Tensor out = self.defineTensor();
+            self.defineRecord(new nvfuser::TensorRecord(
+                {self.recordingState(out())},
+                symbolic_sizes,
+                contiguous,
+                dtype,
+                is_cpu));
+
+            return out;
+          },
+          py::arg("symbolic_sizes"),
+          py::arg("contiguous"),
+          py::arg("dtype") = Nvf::DataType::Float,
+          py::arg("is_cpu") = false,
+          py::return_value_policy::reference)
+      .def(
+          "define_tensor",
+          [](nvfuser::FusionDefinition& self,
+             std::vector<int64_t>& sizes,
+             std::vector<int64_t>& strides,
+             Nvf::DataType dtype = Nvf::DataType::Float,
+             bool is_cpu = false) -> nvfuser::Tensor {
+            FUSER_PERF_SCOPE("FusionDefinition.define_tensor (integration)");
             TORCH_CHECK(
                 sizes.size() == strides.size(),
                 "The number of sizes does not match the number of strides.",
@@ -155,9 +229,9 @@
               }
             }
 
-            nvfuser::Tensor* out = self.defineTensor();
-            self.defineRecord(new nvfuser::InputTensorRecord(
-                {out->index},
+            nvfuser::Tensor out = self.defineTensor();
+            self.defineRecord(new nvfuser::TensorRecord(
+                {self.recordingState(out())},
                 std::move(maybe_symbolic_sizes),
                 std::move(contig_info),
                 dtype,
@@ -167,64 +241,64 @@
           },
           py::arg("sizes"),
           py::arg("strides"),
-          py::arg("dtype") = NvfDataType::Float,
+          py::arg("dtype") = Nvf::DataType::Float,
           py::arg("is_cpu") = false,
           py::return_value_policy::reference)
       .def(
           "define_constant",
-          [](nvfuser::FusionDefinition& self, double val) -> nvfuser::Scalar* {
-            nvfuser::Scalar* out = self.defineScalar();
-            self.defineRecord(
-                new nvfuser::
-                    ConstantRecord<torch::jit::fuser::cuda::Double, double>(
-                        {out->index}, val));
+          [](nvfuser::FusionDefinition& self, double val) -> nvfuser::Scalar {
+            FUSER_PERF_SCOPE("FusionDefinition.define_constant (double)");
+            nvfuser::Scalar out = self.defineScalar();
+            self.defineRecord(new nvfuser::ConstantRecord<Nvf::Double, double>(
+                {self.recordingState(out())}, val));
             return out;
           },
           py::return_value_policy::reference)
       .def(
           "define_constant",
           [](nvfuser::FusionDefinition& self,
-             std::complex<double> val) -> nvfuser::Scalar* {
-            nvfuser::Scalar* out = self.defineScalar();
-            self.defineRecord(new nvfuser::ConstantRecord<
-                              torch::jit::fuser::cuda::ComplexDouble,
-                              c10::complex<double>>(
-                {out->index}, static_cast<c10::complex<double>>(val)));
+             std::complex<double> val) -> nvfuser::Scalar {
+            FUSER_PERF_SCOPE("FusionDefinition.define_constant (complex)");
+            nvfuser::Scalar out = self.defineScalar();
+            self.defineRecord(
+                new nvfuser::
+                    ConstantRecord<Nvf::ComplexDouble, c10::complex<double>>(
+                        {self.recordingState(out())},
+                        static_cast<c10::complex<double>>(val)));
             return out;
           },
           py::return_value_policy::reference)
       .def(
           "define_constant",
-          [](nvfuser::FusionDefinition& self, bool val) -> nvfuser::Scalar* {
-            nvfuser::Scalar* out = self.defineScalar();
-            self.defineRecord(
-                new nvfuser::
-                    ConstantRecord<torch::jit::fuser::cuda::Bool, bool>(
-                        {out->index}, val));
+          [](nvfuser::FusionDefinition& self, bool val) -> nvfuser::Scalar {
+            FUSER_PERF_SCOPE("FusionDefinition.define_constant (bool)");
+            nvfuser::Scalar out = self.defineScalar();
+            self.defineRecord(new nvfuser::ConstantRecord<Nvf::Bool, bool>(
+                {self.recordingState(out())}, val));
             return out;
           },
           py::return_value_policy::reference)
       .def(
           "define_constant",
-          [](nvfuser::FusionDefinition& self, int64_t val) -> nvfuser::Scalar* {
-            nvfuser::Scalar* out = self.defineScalar();
-            self.defineRecord(
-                new nvfuser::
-                    ConstantRecord<torch::jit::fuser::cuda::Int, int64_t>(
-                        {out->index}, val));
+          [](nvfuser::FusionDefinition& self, int64_t val) -> nvfuser::Scalar {
+            FUSER_PERF_SCOPE("FusionDefinition.define_constant (int)");
+            nvfuser::Scalar out = self.defineScalar();
+            self.defineRecord(new nvfuser::ConstantRecord<Nvf::Int, int64_t>(
+                {self.recordingState(out())}, val));
             return out;
           },
           py::return_value_policy::reference)
       .def(
           "define_scalar",
           [](nvfuser::FusionDefinition& self,
-             NvfDataType dtype = torch::jit::fuser::cuda::DataType::Double)
-              -> nvfuser::Scalar* {
-            nvfuser::Scalar* out = self.defineScalar();
-            self.defineRecord(new nvfuser::ScalarRecord({out->index}, dtype));
+             Nvf::DataType dtype = Nvf::DataType::Double) -> nvfuser::Scalar {
+            FUSER_PERF_SCOPE("FusionDefinition.define_scalar");
+            nvfuser::Scalar out = self.defineScalar();
+            self.defineRecord(
+                new nvfuser::ScalarRecord({self.recordingState(out())}, dtype));
             return out;
           },
-          py::arg("dtype") = torch::jit::fuser::cuda::DataType::Double,
+          py::arg("dtype") = Nvf::DataType::Double,
           py::return_value_policy::reference);
 
   //! The Operators class is a nested class of FusionDefinition to allow the
@@ -240,35 +314,39 @@
   nvf_ops.def(py::init<nvfuser::FusionDefinition*>());
 
   // ******************** INSERT OP BINDINGS BELOW HERE ********************
-
-#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name)                  \
-  nvf_ops.def(                                                            \
-      op_str,                                                             \
-      [](nvfuser::FusionDefinition::Operators& self,                      \
-         nvfuser::Tensor* input) -> nvfuser::Tensor* {                    \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \
-        self.fusion_definition->defineRecord(                             \
-            new nvfuser::OpRecord<NvfTensorView*, NvfTensorView*>(        \
-                {input->index},                                           \
-                {output->index},                                          \
-                static_cast<NvfTensorView* (*)(NvfTensorView*)>(          \
-                    torch::jit::fuser::cuda::op_name)));                  \
-        return output;                                                    \
-      },                                                                  \
-      py::return_value_policy::reference);                                \
-  nvf_ops.def(                                                            \
-      op_str,                                                             \
-      [](nvfuser::FusionDefinition::Operators& self,                      \
-         nvfuser::Scalar* input) -> nvfuser::Scalar* {                    \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \
-        self.fusion_definition->defineRecord(                             \
-            new nvfuser::OpRecord<NvfVal*, NvfVal*>(                      \
-                {input->index},                                           \
-                {output->index},                                          \
-                static_cast<NvfVal* (*)(NvfVal*)>(                        \
-                    torch::jit::fuser::cuda::op_name)));                  \
-        return output;                                                    \
-      },                                                                  \
+#define OP_PREFIX "Operators."
+#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name)               \
+  nvf_ops.def(                                                         \
+      op_str,                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                   \
+         nvfuser::Tensor input) -> nvfuser::Tensor {                   \
+        FUSER_PERF_SCOPE("Operators." op_str);                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;        \
+        nvfuser::Tensor output = fd->defineTensor();                   \
+        fd->defineRecord(                                              \
+            new nvfuser::OpRecord<Nvf::TensorView*, Nvf::TensorView*>( \
+                {fd->recordingState(input())},                         \
+                {fd->recordingState(output())},                        \
+                ("ops." op_str),                                       \
+                static_cast<Nvf::TensorView* (*)(Nvf::TensorView*)>(   \
+                    Nvf::op_name)));                                   \
+        return output;                                                 \
+      },                                                               \
+      py::return_value_policy::reference);                             \
+  nvf_ops.def(                                                         \
+      op_str,                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                   \
+         nvfuser::Scalar input) -> nvfuser::Scalar {                   \
+        FUSER_PERF_SCOPE("Operators." op_str);                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;        \
+        nvfuser::Scalar output = fd->defineScalar();                   \
+        fd->defineRecord(new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*>(  \
+            {fd->recordingState(input())},                             \
+            {fd->recordingState(output())},                            \
+            ("ops." op_str),                                           \
+            static_cast<Nvf::Val* (*)(Nvf::Val*)>(Nvf::op_name)));     \
+        return output;                                                 \
+      },                                                               \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_UNARY_OP("abs", abs)
@@ -317,68 +395,85 @@
   NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
 #undef NVFUSER_PYTHON_BINDING_UNARY_OP
 
-#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name)                    \
-  nvf_ops.def(                                                               \
-      op_str,                                                                \
-      [](nvfuser::FusionDefinition::Operators& self,                         \
-         nvfuser::Tensor* arg1,                                              \
-         nvfuser::Tensor* arg2) -> nvfuser::Tensor* {                        \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();    \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<          \
-                                             NvfTensorView*,                 \
-                                             NvfTensorView*,                 \
-                                             NvfTensorView*>(                \
-            {arg1->index, arg2->index},                                      \
-            {output->index},                                                 \
-            static_cast<NvfTensorView* (*)(NvfTensorView*, NvfTensorView*)>( \
-                torch::jit::fuser::cuda::op_name)));                         \
-        return output;                                                       \
-      },                                                                     \
-      py::return_value_policy::reference);                                   \
-  nvf_ops.def(                                                               \
-      op_str,                                                                \
-      [](nvfuser::FusionDefinition::Operators& self,                         \
-         nvfuser::Tensor* arg1,                                              \
-         nvfuser::Scalar* arg2) -> nvfuser::Tensor* {                        \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();    \
-        self.fusion_definition->defineRecord(                                \
-            new nvfuser::OpRecord<NvfTensorView*, NvfTensorView*, NvfVal*>(  \
-                {arg1->index, arg2->index},                                  \
-                {output->index},                                             \
-                static_cast<NvfTensorView* (*)(NvfTensorView*, NvfVal*)>(    \
-                    torch::jit::fuser::cuda::op_name)));                     \
-        return output;                                                       \
-      },                                                                     \
-      py::return_value_policy::reference);                                   \
-  nvf_ops.def(                                                               \
-      op_str,                                                                \
-      [](nvfuser::FusionDefinition::Operators& self,                         \
-         nvfuser::Scalar* arg1,                                              \
-         nvfuser::Tensor* arg2) -> nvfuser::Tensor* {                        \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();    \
-        self.fusion_definition->defineRecord(                                \
-            new nvfuser::OpRecord<NvfTensorView*, NvfVal*, NvfTensorView*>(  \
-                {arg1->index, arg2->index},                                  \
-                {output->index},                                             \
-                static_cast<NvfTensorView* (*)(NvfVal*, NvfTensorView*)>(    \
-                    torch::jit::fuser::cuda::op_name)));                     \
-        return output;                                                       \
-      },                                                                     \
-      py::return_value_policy::reference);                                   \
-  nvf_ops.def(                                                               \
-      op_str,                                                                \
-      [](nvfuser::FusionDefinition::Operators& self,                         \
-         nvfuser::Scalar* arg1,                                              \
-         nvfuser::Scalar* arg2) -> nvfuser::Scalar* {                        \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();    \
-        self.fusion_definition->defineRecord(                                \
-            new nvfuser::OpRecord<NvfVal*, NvfVal*, NvfVal*>(                \
-                {arg1->index, arg2->index},                                  \
-                {output->index},                                             \
-                static_cast<NvfVal* (*)(NvfVal*, NvfVal*)>(                  \
-                    torch::jit::fuser::cuda::op_name)));                     \
-        return output;                                                       \
-      },                                                                     \
+#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name)                   \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Tensor arg1,                                              \
+         nvfuser::Tensor arg2) -> nvfuser::Tensor {                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Tensor output = fd->defineTensor();                        \
+        fd->defineRecord(new nvfuser::OpRecord<                             \
+                         Nvf::TensorView*,                                  \
+                         Nvf::TensorView*,                                  \
+                         Nvf::TensorView*>(                                 \
+            {fd->recordingState(arg1()), fd->recordingState(arg2())},       \
+            {fd->recordingState(output())},                                 \
+            ("ops." op_str),                                                \
+            static_cast<                                                    \
+                Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*)>(  \
+                Nvf::op_name)));                                            \
+        return output;                                                      \
+      },                                                                    \
+      py::return_value_policy::reference);                                  \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Tensor arg1,                                              \
+         nvfuser::Scalar arg2) -> nvfuser::Tensor {                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Tensor output = fd->defineTensor();                        \
+        fd->defineRecord(new nvfuser::OpRecord<                             \
+                         Nvf::TensorView*,                                  \
+                         Nvf::TensorView*,                                  \
+                         Nvf::Val*>(                                        \
+            {fd->recordingState(arg1()), fd->recordingState(arg2())},       \
+            {fd->recordingState(output())},                                 \
+            ("ops." op_str),                                                \
+            static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>( \
+                Nvf::op_name)));                                            \
+        return output;                                                      \
+      },                                                                    \
+      py::return_value_policy::reference);                                  \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Scalar arg1,                                              \
+         nvfuser::Tensor arg2) -> nvfuser::Tensor {                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Tensor output = fd->defineTensor();                        \
+        fd->defineRecord(new nvfuser::OpRecord<                             \
+                         Nvf::TensorView*,                                  \
+                         Nvf::Val*,                                         \
+                         Nvf::TensorView*>(                                 \
+            {fd->recordingState(arg1()), fd->recordingState(arg2())},       \
+            {fd->recordingState(output())},                                 \
+            ("ops." op_str),                                                \
+            static_cast<Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*)>( \
+                Nvf::op_name)));                                            \
+        return output;                                                      \
+      },                                                                    \
+      py::return_value_policy::reference);                                  \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Scalar arg1,                                              \
+         nvfuser::Scalar arg2) -> nvfuser::Scalar {                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Scalar output = fd->defineScalar();                        \
+        fd->defineRecord(                                                   \
+            new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*>(         \
+                {fd->recordingState(arg1()), fd->recordingState(arg2())},   \
+                {fd->recordingState(output())},                             \
+                ("ops." op_str),                                            \
+                static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*)>(           \
+                    Nvf::op_name)));                                        \
+        return output;                                                      \
+      },                                                                    \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_BINARY_OP("add", add)
@@ -403,234 +498,309 @@
   NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift", bitwise_left_shift)
 #undef NVFUSER_PYTHON_BINDING_BINARY_OP
 
-#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name)           \
-  nvf_ops.def(                                                                 \
-      op_str,                                                                  \
-      [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Tensor* arg1,                                                \
-         nvfuser::Tensor* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                          \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();      \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<            \
-                                             NvfTensorView*,                   \
-                                             NvfTensorView*,                   \
-                                             NvfTensorView*,                   \
-                                             NvfVal*>(                         \
-            {arg1->index, arg2->index, arg3->index},                           \
-            {output->index},                                                   \
-            static_cast<                                                       \
-                NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*)>(  \
-                torch::jit::fuser::cuda::op_name)));                           \
-        return output;                                                         \
-      },                                                                       \
-      py::return_value_policy::reference);                                     \
-  nvf_ops.def(                                                                 \
-      op_str,                                                                  \
-      [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Tensor* arg1,                                                \
-         nvfuser::Scalar* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                          \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();      \
-        self.fusion_definition->defineRecord(                                  \
-            new nvfuser::                                                      \
-                OpRecord<NvfTensorView*, NvfTensorView*, NvfVal*, NvfVal*>(    \
-                    {arg1->index, arg2->index, arg3->index},                   \
-                    {output->index},                                           \
-                    static_cast<                                               \
-                        NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>( \
-                        torch::jit::fuser::cuda::op_name)));                   \
-        return output;                                                         \
-      },                                                                       \
-      py::return_value_policy::reference);                                     \
-  nvf_ops.def(                                                                 \
-      op_str,                                                                  \
-      [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Scalar* arg1,                                                \
-         nvfuser::Tensor* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                          \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();      \
-        self.fusion_definition->defineRecord(                                  \
-            new nvfuser::                                                      \
-                OpRecord<NvfTensorView*, NvfVal*, NvfTensorView*, NvfVal*>(    \
-                    {arg1->index, arg2->index, arg3->index},                   \
-                    {output->index},                                           \
-                    static_cast<                                               \
-                        NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*)>( \
-                        torch::jit::fuser::cuda::op_name)));                   \
-        return output;                                                         \
-      },                                                                       \
-      py::return_value_policy::reference);                                     \
-  nvf_ops.def(                                                                 \
-      op_str,                                                                  \
-      [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Scalar* arg1,                                                \
-         nvfuser::Scalar* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Scalar* {                          \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();      \
-        self.fusion_definition->defineRecord(                                  \
-            new nvfuser::OpRecord<NvfVal*, NvfVal*, NvfVal*, NvfVal*>(         \
-                {arg1->index, arg2->index, arg3->index},                       \
-                {output->index},                                               \
-                static_cast<NvfVal* (*)(NvfVal*, NvfVal*, NvfVal*)>(           \
-                    torch::jit::fuser::cuda::op_name)));                       \
-        return output;                                                         \
-      },                                                                       \
+#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name)                 \
+  nvf_ops.def(                                                                       \
+      op_str,                                                                        \
+      [](nvfuser::FusionDefinition::Operators& self,                                 \
+         nvfuser::Tensor arg1,                                                       \
+         nvfuser::Tensor arg2,                                                       \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                  \
+        FUSER_PERF_SCOPE("Operators." op_str);                                       \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                      \
+        nvfuser::Tensor output = fd->defineTensor();                                 \
+        fd->defineRecord(new nvfuser::OpRecord<                                      \
+                         Nvf::TensorView*,                                           \
+                         Nvf::TensorView*,                                           \
+                         Nvf::TensorView*,                                           \
+                         Nvf::Val*>(                                                 \
+            {fd->recordingState(arg1()),                                             \
+             fd->recordingState(arg2()),                                             \
+             fd->recordingState(arg3())},                                            \
+            {fd->recordingState(output())},                                          \
+            ("ops." op_str),                                                         \
+            static_cast<                                                             \
+                Nvf::                                                                \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
+                Nvf::op_name)));                                                     \
+        return output;                                                               \
+      },                                                                             \
+      py::return_value_policy::reference);                                           \
+  nvf_ops.def(                                                                       \
+      op_str,                                                                        \
+      [](nvfuser::FusionDefinition::Operators& self,                                 \
+         nvfuser::Tensor arg1,                                                       \
+         nvfuser::Scalar arg2,                                                       \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                  \
+        FUSER_PERF_SCOPE("Operators." op_str);                                       \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                      \
+        nvfuser::Tensor output = fd->defineTensor();                                 \
+        fd->defineRecord(new nvfuser::OpRecord<                                      \
+                         Nvf::TensorView*,                                           \
+                         Nvf::TensorView*,                                           \
+                         Nvf::Val*,                                                  \
+                         Nvf::Val*>(                                                 \
+            {fd->recordingState(arg1()),                                             \
+             fd->recordingState(arg2()),                                             \
+             fd->recordingState(arg3())},                                            \
+            {fd->recordingState(output())},                                          \
+            ("ops." op_str),                                                         \
+            static_cast<                                                             \
+                Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>(       \
+                Nvf::op_name)));                                                     \
+        return output;                                                               \
+      },                                                                             \
+      py::return_value_policy::reference);                                           \
+  nvf_ops.def(                                                                       \
+      op_str,                                                                        \
+      [](nvfuser::FusionDefinition::Operators& self,                                 \
+         nvfuser::Scalar arg1,                                                       \
+         nvfuser::Tensor arg2,                                                       \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                  \
+        FUSER_PERF_SCOPE("Operators." op_str);                                       \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                      \
+        nvfuser::Tensor output = fd->defineTensor();                                 \
+        fd->defineRecord(new nvfuser::OpRecord<                                      \
+                         Nvf::TensorView*,                                           \
+                         Nvf::Val*,                                                  \
+                         Nvf::TensorView*,                                           \
+                         Nvf::Val*>(                                                 \
+            {fd->recordingState(arg1()),                                             \
+             fd->recordingState(arg2()),                                             \
+             fd->recordingState(arg3())},                                            \
+            {fd->recordingState(output())},                                          \
+            ("ops." op_str),                                                         \
+            static_cast<                                                             \
+                Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>(       \
+                Nvf::op_name)));                                                     \
+        return output;                                                               \
+      },                                                                             \
+      py::return_value_policy::reference);                                           \
+  nvf_ops.def(                                                                       \
+      op_str,                                                                        \
+      [](nvfuser::FusionDefinition::Operators& self,                                 \
+         nvfuser::Scalar arg1,                                                       \
+         nvfuser::Scalar arg2,                                                       \
+         nvfuser::Scalar arg3) -> nvfuser::Scalar {                                  \
+        FUSER_PERF_SCOPE("Operators." op_str);                                       \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                      \
+        nvfuser::Scalar output = fd->defineScalar();                                 \
+        fd->defineRecord(                                                            \
+            new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>(       \
+                {fd->recordingState(arg1()),                                         \
+                 fd->recordingState(arg2()),                                         \
+                 fd->recordingState(arg3())},                                        \
+                {fd->recordingState(output())},                                      \
+                ("ops." op_str),                                                     \
+                static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>(         \
+                    Nvf::op_name)));                                                 \
+        return output;                                                               \
+      },                                                                             \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("add_alpha", add_alpha)
   NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("sub_alpha", sub_alpha)
 #undef NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP
 
-#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name)                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Scalar* arg1,                                                      \
-         nvfuser::Scalar* arg2,                                                      \
-         nvfuser::Scalar* arg3) -> nvfuser::Scalar* {                                \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();            \
-        self.fusion_definition->defineRecord(                                        \
-            new nvfuser::OpRecord<NvfVal*, NvfVal*, NvfVal*, NvfVal*>(               \
-                {arg1->index, arg2->index, arg3->index},                             \
-                {output->index},                                                     \
-                static_cast<NvfVal* (*)(NvfVal*, NvfVal*, NvfVal*)>(                 \
-                    torch::jit::fuser::cuda::op_name)));                             \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Tensor* arg1,                                                      \
-         nvfuser::Tensor* arg2,                                                      \
-         nvfuser::Tensor* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                  \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*>(                        \
-            {arg1->index, arg2->index, arg3->index},                                 \
-            {output->index},                                                         \
-            static_cast<                                                             \
-                NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfTensorView*)>( \
-                torch::jit::fuser::cuda::op_name)));                                 \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Tensor* arg1,                                                      \
-         nvfuser::Tensor* arg2,                                                      \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                  \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*,                         \
-                                             NvfVal*>(                               \
-            {arg1->index, arg2->index, arg3->index},                                 \
-            {output->index},                                                         \
-            static_cast<                                                             \
-                NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                 \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Tensor* arg1,                                                      \
-         nvfuser::Scalar* arg2,                                                      \
-         nvfuser::Tensor* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                  \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*,                         \
-                                             NvfVal*,                                \
-                                             NvfTensorView*>(                        \
-            {arg1->index, arg2->index, arg3->index},                                 \
-            {output->index},                                                         \
-            static_cast<                                                             \
-                NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfTensorView*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                 \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Scalar* arg1,                                                      \
-         nvfuser::Tensor* arg2,                                                      \
-         nvfuser::Tensor* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                  \
-                                             NvfTensorView*,                         \
-                                             NvfVal*,                                \
-                                             NvfTensorView*,                         \
-                                             NvfTensorView*>(                        \
-            {arg1->index, arg2->index, arg3->index},                                 \
-            {output->index},                                                         \
-            static_cast<                                                             \
-                NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfTensorView*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                 \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Scalar* arg1,                                                      \
-         nvfuser::Scalar* arg2,                                                      \
-         nvfuser::Tensor* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(                                        \
-            new nvfuser::                                                            \
-                OpRecord<NvfTensorView*, NvfVal*, NvfVal*, NvfTensorView*>(          \
-                    {arg1->index, arg2->index, arg3->index},                         \
-                    {output->index},                                                 \
-                    static_cast<                                                     \
-                        NvfTensorView* (*)(NvfVal*, NvfVal*, NvfTensorView*)>(       \
-                        torch::jit::fuser::cuda::op_name)));                         \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Tensor* arg1,                                                      \
-         nvfuser::Scalar* arg2,                                                      \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(                                        \
-            new nvfuser::                                                            \
-                OpRecord<NvfTensorView*, NvfTensorView*, NvfVal*, NvfVal*>(          \
-                    {arg1->index, arg2->index, arg3->index},                         \
-                    {output->index},                                                 \
-                    static_cast<                                                     \
-                        NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>(       \
-                        torch::jit::fuser::cuda::op_name)));                         \
-        return output;                                                               \
-      },                                                                             \
-      py::return_value_policy::reference);                                           \
-  nvf_ops.def(                                                                       \
-      op_str,                                                                        \
-      [](nvfuser::FusionDefinition::Operators& self,                                 \
-         nvfuser::Scalar* arg1,                                                      \
-         nvfuser::Tensor* arg2,                                                      \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                                \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();            \
-        self.fusion_definition->defineRecord(                                        \
-            new nvfuser::                                                            \
-                OpRecord<NvfTensorView*, NvfVal*, NvfTensorView*, NvfVal*>(          \
-                    {arg1->index, arg2->index, arg3->index},                         \
-                    {output->index},                                                 \
-                    static_cast<                                                     \
-                        NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*)>(       \
-                        torch::jit::fuser::cuda::op_name)));                         \
-        return output;                                                               \
-      },                                                                             \
+#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name)                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Scalar arg1,                                                              \
+         nvfuser::Scalar arg2,                                                              \
+         nvfuser::Scalar arg3) -> nvfuser::Scalar {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Scalar output = fd->defineScalar();                                        \
+        fd->defineRecord(                                                                   \
+            new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>(              \
+                {fd->recordingState(arg1()),                                                \
+                 fd->recordingState(arg2()),                                                \
+                 fd->recordingState(arg3())},                                               \
+                {fd->recordingState(output())},                                             \
+                ("ops." op_str),                                                            \
+                static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>(                \
+                    Nvf::op_name)));                                                        \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Tensor arg1,                                                              \
+         nvfuser::Tensor arg2,                                                              \
+         nvfuser::Tensor arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*>(                                                 \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::                                                                       \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*)>( \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Tensor arg1,                                                              \
+         nvfuser::Tensor arg2,                                                              \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*>(                                                        \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::                                                                       \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>(        \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Tensor arg1,                                                              \
+         nvfuser::Scalar arg2,                                                              \
+         nvfuser::Tensor arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*,                                                         \
+                         Nvf::TensorView*>(                                                 \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::                                                                       \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*)>(        \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Scalar arg1,                                                              \
+         nvfuser::Tensor arg2,                                                              \
+         nvfuser::Tensor arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*,                                                         \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*>(                                                 \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::                                                                       \
+                    TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*)>(        \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Scalar arg1,                                                              \
+         nvfuser::Scalar arg2,                                                              \
+         nvfuser::Tensor arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*,                                                         \
+                         Nvf::Val*,                                                         \
+                         Nvf::TensorView*>(                                                 \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*)>(              \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Tensor arg1,                                                              \
+         nvfuser::Scalar arg2,                                                              \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*,                                                         \
+                         Nvf::Val*>(                                                        \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>(              \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
+      py::return_value_policy::reference);                                                  \
+  nvf_ops.def(                                                                              \
+      op_str,                                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                                        \
+         nvfuser::Scalar arg1,                                                              \
+         nvfuser::Tensor arg2,                                                              \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                                         \
+        FUSER_PERF_SCOPE("Operators." op_str);                                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                             \
+        nvfuser::Tensor output = fd->defineTensor();                                        \
+        fd->defineRecord(new nvfuser::OpRecord<                                             \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*,                                                         \
+                         Nvf::TensorView*,                                                  \
+                         Nvf::Val*>(                                                        \
+            {fd->recordingState(arg1()),                                                    \
+             fd->recordingState(arg2()),                                                    \
+             fd->recordingState(arg3())},                                                   \
+            {fd->recordingState(output())},                                                 \
+            ("ops." op_str),                                                                \
+            static_cast<                                                                    \
+                Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>(              \
+                Nvf::op_name)));                                                            \
+        return output;                                                                      \
+      },                                                                                    \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_TERNARY_OP("lerp", lerp)
@@ -641,34 +811,46 @@
   nvf_ops.def(                                                                 \
       op_str,                                                                  \
       [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Scalar* arg1,                                                \
-         nvfuser::Scalar* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Scalar* {                          \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();      \
-        self.fusion_definition->defineRecord(                                  \
-            new nvfuser::OpRecord<NvfVal*, NvfVal*, NvfVal*, NvfVal*>(         \
-                {arg1->index, arg2->index, arg3->index},                       \
-                {output->index},                                               \
-                static_cast<NvfVal* (*)(NvfVal*, NvfVal*, NvfVal*)>(           \
-                    torch::jit::fuser::cuda::op_name)));                       \
+         nvfuser::Scalar arg1,                                                 \
+         nvfuser::Scalar arg2,                                                 \
+         nvfuser::Scalar arg3) -> nvfuser::Scalar {                            \
+        FUSER_PERF_SCOPE("Operators." op_str);                                 \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                \
+        nvfuser::Scalar output = fd->defineScalar();                           \
+        fd->defineRecord(                                                      \
+            new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \
+                {fd->recordingState(arg1()),                                   \
+                 fd->recordingState(arg2()),                                   \
+                 fd->recordingState(arg3())},                                  \
+                {fd->recordingState(output())},                                \
+                ("ops." op_str),                                               \
+                static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>(   \
+                    Nvf::op_name)));                                           \
         return output;                                                         \
       },                                                                       \
       py::return_value_policy::reference);                                     \
   nvf_ops.def(                                                                 \
       op_str,                                                                  \
       [](nvfuser::FusionDefinition::Operators& self,                           \
-         nvfuser::Tensor* arg1,                                                \
-         nvfuser::Scalar* arg2,                                                \
-         nvfuser::Scalar* arg3) -> nvfuser::Tensor* {                          \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();      \
-        self.fusion_definition->defineRecord(                                  \
-            new nvfuser::                                                      \
-                OpRecord<NvfTensorView*, NvfTensorView*, NvfVal*, NvfVal*>(    \
-                    {arg1->index, arg2->index, arg3->index},                   \
-                    {output->index},                                           \
-                    static_cast<                                               \
-                        NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>( \
-                        torch::jit::fuser::cuda::op_name)));                   \
+         nvfuser::Tensor arg1,                                                 \
+         nvfuser::Scalar arg2,                                                 \
+         nvfuser::Scalar arg3) -> nvfuser::Tensor {                            \
+        FUSER_PERF_SCOPE("Operators." op_str);                                 \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                \
+        nvfuser::Tensor output = fd->defineTensor();                           \
+        fd->defineRecord(new nvfuser::OpRecord<                                \
+                         Nvf::TensorView*,                                     \
+                         Nvf::TensorView*,                                     \
+                         Nvf::Val*,                                            \
+                         Nvf::Val*>(                                           \
+            {fd->recordingState(arg1()),                                       \
+             fd->recordingState(arg2()),                                       \
+             fd->recordingState(arg3())},                                      \
+            {fd->recordingState(output())},                                    \
+            ("ops." op_str),                                                   \
+            static_cast<                                                       \
+                Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
+                Nvf::op_name)));                                               \
         return output;                                                         \
       },                                                                       \
       py::return_value_policy::reference);
@@ -677,206 +859,270 @@
   NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("threshold", threshold)
 #undef NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP
 
-#define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name)                         \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Scalar* arg1,                                                               \
-         nvfuser::Scalar* arg2,                                                               \
-         nvfuser::Scalar* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Scalar* {                                         \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();                     \
-        self.fusion_definition->defineRecord(                                                 \
-            new nvfuser::                                                                     \
-                OpRecord<NvfVal*, NvfVal*, NvfVal*, NvfVal*, NvfVal*>(                        \
-                    {arg1->index, arg2->index, arg3->index, arg4->index},                     \
-                    {output->index},                                                          \
-                    static_cast<                                                              \
-                        NvfVal* (*)(NvfVal*, NvfVal*, NvfVal*, NvfVal*)>(                     \
-                        torch::jit::fuser::cuda::op_name)));                                  \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Tensor* arg1,                                                               \
-         nvfuser::Tensor* arg2,                                                               \
-         nvfuser::Tensor* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*>(                                 \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfTensorView*, NvfVal*)>( \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Tensor* arg1,                                                               \
-         nvfuser::Tensor* arg2,                                                               \
-         nvfuser::Scalar* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*, NvfVal*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Tensor* arg1,                                                               \
-         nvfuser::Scalar* arg2,                                                               \
-         nvfuser::Tensor* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfTensorView*, NvfVal*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Scalar* arg1,                                                               \
-         nvfuser::Tensor* arg2,                                                               \
-         nvfuser::Tensor* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfTensorView*, NvfVal*)>(        \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Scalar* arg1,                                                               \
-         nvfuser::Scalar* arg2,                                                               \
-         nvfuser::Tensor* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfVal*,                                         \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfVal*, NvfVal*, NvfTensorView*, NvfVal*)>(               \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Tensor* arg1,                                                               \
-         nvfuser::Scalar* arg2,                                                               \
-         nvfuser::Scalar* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfVal*,                                         \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*, NvfVal*)>(               \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
-      py::return_value_policy::reference);                                                    \
-  nvf_ops.def(                                                                                \
-      op_str,                                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                                          \
-         nvfuser::Scalar* arg1,                                                               \
-         nvfuser::Tensor* arg2,                                                               \
-         nvfuser::Scalar* arg3,                                                               \
-         nvfuser::Scalar* arg4) -> nvfuser::Tensor* {                                         \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();                     \
-        self.fusion_definition->defineRecord(new nvfuser::OpRecord<                           \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfTensorView*,                                  \
-                                             NvfVal*,                                         \
-                                             NvfVal*>(                                        \
-            {arg1->index, arg2->index, arg3->index, arg4->index},                             \
-            {output->index},                                                                  \
-            static_cast<                                                                      \
-                NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*, NvfVal*)>(               \
-                torch::jit::fuser::cuda::op_name)));                                          \
-        return output;                                                                        \
-      },                                                                                      \
+#define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name)                                  \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Scalar arg1,                                                                         \
+         nvfuser::Scalar arg2,                                                                         \
+         nvfuser::Scalar arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Scalar {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Scalar output = fd->defineScalar();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>(                            \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Tensor arg1,                                                                         \
+         nvfuser::Tensor arg2,                                                                         \
+         nvfuser::Tensor arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*>(                                                            \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Tensor arg1,                                                                         \
+         nvfuser::Tensor arg2,                                                                         \
+         nvfuser::Scalar arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>(        \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Tensor arg1,                                                                         \
+         nvfuser::Scalar arg2,                                                                         \
+         nvfuser::Tensor arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>(        \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Scalar arg1,                                                                         \
+         nvfuser::Tensor arg2,                                                                         \
+         nvfuser::Tensor arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>(        \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Scalar arg1,                                                                         \
+         nvfuser::Scalar arg2,                                                                         \
+         nvfuser::Tensor arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*,                                                                    \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>(               \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Tensor arg1,                                                                         \
+         nvfuser::Scalar arg2,                                                                         \
+         nvfuser::Scalar arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>(               \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
+      py::return_value_policy::reference);                                                             \
+  nvf_ops.def(                                                                                         \
+      op_str,                                                                                          \
+      [](nvfuser::FusionDefinition::Operators& self,                                                   \
+         nvfuser::Scalar arg1,                                                                         \
+         nvfuser::Tensor arg2,                                                                         \
+         nvfuser::Scalar arg3,                                                                         \
+         nvfuser::Scalar arg4) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                         \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                        \
+        nvfuser::Tensor output = fd->defineTensor();                                                   \
+        fd->defineRecord(new nvfuser::OpRecord<                                                        \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::TensorView*,                                                             \
+                         Nvf::Val*,                                                                    \
+                         Nvf::Val*>(                                                                   \
+            {fd->recordingState(arg1()),                                                               \
+             fd->recordingState(arg2()),                                                               \
+             fd->recordingState(arg3()),                                                               \
+             fd->recordingState(arg4())},                                                              \
+            {fd->recordingState(output())},                                                            \
+            ("ops." op_str),                                                                           \
+            static_cast<                                                                               \
+                Nvf::                                                                                  \
+                    TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>(               \
+                Nvf::op_name)));                                                                       \
+        return output;                                                                                 \
+      },                                                                                               \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul)
 #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP
 
-#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name)                 \
-  nvf_ops.def(                                                               \
-      op_str,                                                                \
-      [](nvfuser::FusionDefinition::Operators& self,                         \
-         nvfuser::Tensor* arg,                                               \
-         const std::vector<int>& axes,                                       \
-         bool keep_dim,                                                      \
-         NvfDataType dtype) -> nvfuser::Tensor* {                            \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();    \
-        self.fusion_definition->defineRecord(new nvfuser::ReductionOpRecord( \
-            {arg->index},                                                    \
-            {output->index},                                                 \
-            torch::jit::fuser::cuda::op_name,                                \
-            axes,                                                            \
-            keep_dim,                                                        \
-            dtype));                                                         \
-        return output;                                                       \
-      },                                                                     \
-      py::arg("arg"),                                                        \
-      py::arg("axes"),                                                       \
-      py::arg("keep_dim"),                                                   \
-      py::arg("dtype") = torch::jit::fuser::cuda::DataType::Null,            \
+#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name)                                          \
+  nvf_ops.def(                                                                                        \
+      op_str,                                                                                         \
+      [](nvfuser::FusionDefinition::Operators& self,                                                  \
+         nvfuser::Tensor arg,                                                                         \
+         const std::vector<int>& axes,                                                                \
+         bool keepdim,                                                                                \
+         Nvf::DataType dtype) -> nvfuser::Tensor {                                                    \
+        FUSER_PERF_SCOPE("Operators." op_str);                                                        \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;                                       \
+        nvfuser::Tensor output = fd->defineTensor();                                                  \
+        fd->defineRecord(new nvfuser::ReductionOpRecord(                                              \
+            {fd->recordingState(arg())},                                                              \
+            {fd->recordingState(output())},                                                           \
+            ("ops." op_str),                                                                          \
+            static_cast<                                                                              \
+                Nvf::                                                                                 \
+                    TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>( \
+                Nvf::op_name),                                                                        \
+            axes,                                                                                     \
+            keepdim,                                                                                  \
+            dtype));                                                                                  \
+        return output;                                                                                \
+      },                                                                                              \
+      py::arg("arg"),                                                                                 \
+      py::arg("axes"),                                                                                \
+      py::arg("keepdim") = false,                                                                     \
+      py::arg("dtype") = Nvf::DataType::Null,                                                         \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_REDUCTION_OP("sum", sum)
@@ -884,38 +1130,48 @@
   NVFUSER_PYTHON_BINDING_REDUCTION_OP("min", min)
 #undef NVFUSER_PYTHON_BINDING_REDUCTION_OP
 
-#define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name)                       \
-  nvf_ops.def(                                                                \
-      op_str,                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                          \
-         nvfuser::Tensor* arg,                                                \
-         NvfDataType dtype) -> nvfuser::Tensor* {                             \
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();     \
-        self.fusion_definition->defineRecord(                                 \
-            new nvfuser::CastOpRecord<NvfTensorView*, NvfTensorView*>(        \
-                {arg->index},                                                 \
-                {output->index},                                              \
-                static_cast<NvfTensorView* (*)(NvfDataType, NvfTensorView*)>( \
-                    torch::jit::fuser::cuda::op_name),                        \
-                dtype));                                                      \
-        return output;                                                        \
-      },                                                                      \
-      py::return_value_policy::reference);                                    \
-  nvf_ops.def(                                                                \
-      op_str,                                                                 \
-      [](nvfuser::FusionDefinition::Operators& self,                          \
-         nvfuser::Scalar* arg,                                                \
-         NvfDataType dtype) -> nvfuser::Scalar* {                             \
-        nvfuser::Scalar* output = self.fusion_definition->defineScalar();     \
-        self.fusion_definition->defineRecord(                                 \
-            new nvfuser::CastOpRecord<NvfVal*, NvfVal*>(                      \
-                {arg->index},                                                 \
-                {output->index},                                              \
-                static_cast<NvfVal* (*)(NvfDataType, NvfVal*)>(               \
-                    torch::jit::fuser::cuda::op_name),                        \
-                dtype));                                                      \
-        return output;                                                        \
-      },                                                                      \
+#define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name)                     \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Tensor arg,                                               \
+         Nvf::DataType dtype) -> nvfuser::Tensor {                          \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Tensor output = fd->defineTensor();                        \
+        fd->defineRecord(                                                   \
+            new nvfuser::CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>(  \
+                {fd->recordingState(arg())},                                \
+                {fd->recordingState(output())},                             \
+                ("ops." op_str),                                            \
+                static_cast<                                                \
+                    Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>( \
+                    Nvf::op_name),                                          \
+                dtype));                                                    \
+        return output;                                                      \
+      },                                                                    \
+      py::arg("arg"),                                                       \
+      py::arg("dtype"),                                                     \
+      py::return_value_policy::reference);                                  \
+  nvf_ops.def(                                                              \
+      op_str,                                                               \
+      [](nvfuser::FusionDefinition::Operators& self,                        \
+         nvfuser::Scalar arg,                                               \
+         Nvf::DataType dtype) -> nvfuser::Scalar {                          \
+        FUSER_PERF_SCOPE("Operators." op_str);                              \
+        nvfuser::FusionDefinition* fd = self.fusion_definition;             \
+        nvfuser::Scalar output = fd->defineScalar();                        \
+        fd->defineRecord(new nvfuser::CastOpRecord<Nvf::Val*, Nvf::Val*>(   \
+            {fd->recordingState(arg())},                                    \
+            {fd->recordingState(output())},                                 \
+            ("ops." op_str),                                                \
+            static_cast<Nvf::Val* (*)(Nvf::DataType, Nvf::Val*)>(           \
+                Nvf::op_name),                                              \
+            dtype));                                                        \
+        return output;                                                      \
+      },                                                                    \
+      py::arg("arg"),                                                       \
+      py::arg("dtype"),                                                     \
       py::return_value_policy::reference);
 
   NVFUSER_PYTHON_BINDING_CAST_OP("cast", castOp)
@@ -924,60 +1180,93 @@
   nvf_ops.def(
       "squeeze",
       [](nvfuser::FusionDefinition::Operators& self,
-         nvfuser::Tensor* arg,
+         nvfuser::Tensor arg,
          std::vector<int64_t>& original_shape,
-         int64_t dim) -> nvfuser::Tensor* {
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();
-        self.fusion_definition->defineRecord(new nvfuser::SqueezeOpRecord(
-            {arg->index}, {output->index}, original_shape, dim));
+         int64_t dim) -> nvfuser::Tensor {
+        FUSER_PERF_SCOPE("Operators.squeeze");
+        nvfuser::FusionDefinition* fd = self.fusion_definition;
+        nvfuser::Tensor output = fd->defineTensor();
+        fd->defineRecord(new nvfuser::SqueezeOpRecord(
+            {fd->recordingState(arg())},
+            {fd->recordingState(output())},
+            original_shape,
+            dim));
         return output;
       },
+      py::arg("arg"),
+      py::arg("original_shape"),
+      py::arg("dim"),
       py::return_value_policy::reference);
-
   nvf_ops.def(
       "var",
       [](nvfuser::FusionDefinition::Operators& self,
-         nvfuser::Tensor* arg,
+         nvfuser::Tensor arg,
          std::vector<int>& axes,
          int64_t correction,
-         bool keepdim) -> nvfuser::Tensor* {
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();
-        self.fusion_definition->defineRecord(new nvfuser::VarianceOpRecord(
-            {arg->index}, {output->index}, axes, correction, keepdim));
+         bool keepdim) -> nvfuser::Tensor {
+        FUSER_PERF_SCOPE("Operators.var");
+        nvfuser::FusionDefinition* fd = self.fusion_definition;
+        nvfuser::Tensor output = fd->defineTensor();
+        fd->defineRecord(new nvfuser::VarianceOpRecord(
+            {fd->recordingState(arg())},
+            {fd->recordingState(output())},
+            axes,
+            correction,
+            keepdim));
         return output;
       },
+      py::arg("arg"),
+      py::arg("axes"),
+      py::arg("correction"),
+      py::arg("keepdim") = false,
       py::return_value_policy::reference);
-
   nvf_ops.def(
       "var_mean",
       [](nvfuser::FusionDefinition::Operators& self,
-         nvfuser::Tensor* arg,
-         std::vector<int>& dims,
+         nvfuser::Tensor arg,
+         std::vector<int>& axes,
          int64_t correction,
          bool keepdim) -> decltype(auto) {
-        nvfuser::Tensor* var = self.fusion_definition->defineTensor();
-        nvfuser::Tensor* mean = self.fusion_definition->defineTensor();
-        self.fusion_definition->defineRecord(new nvfuser::VarianceMeanOpRecord(
-            {arg->index},
-            {var->index, mean->index},
-            dims,
+        FUSER_PERF_SCOPE("Operators.var_mean");
+        nvfuser::FusionDefinition* fd = self.fusion_definition;
+        nvfuser::Tensor var = fd->defineTensor();
+        nvfuser::Tensor mean = fd->defineTensor();
+        fd->defineRecord(new nvfuser::VarianceMeanOpRecord(
+            {fd->recordingState(arg())},
+            {fd->recordingState(var()), fd->recordingState(mean())},
+            axes,
             correction,
             keepdim));
         return std::make_tuple(var, mean);
       },
+      py::arg("arg"),
+      py::arg("axes"),
+      py::arg("correction"),
+      py::arg("keepdim") = false,
       py::return_value_policy::reference);
-
   nvf_ops.def(
       "broadcast_in_dim",
       [](nvfuser::FusionDefinition::Operators& self,
-         nvfuser::Tensor* arg,
+         nvfuser::Tensor arg,
          std::vector<int64_t>& output_shape,
-         std::vector<int64_t>& broadcast_dims) -> nvfuser::Tensor* {
-        nvfuser::Tensor* output = self.fusion_definition->defineTensor();
-        self.fusion_definition->defineRecord(new nvfuser::BroadcastOpRecord(
-            {arg->index}, {output->index}, output_shape, broadcast_dims));
+         std::vector<int64_t>& broadcast_dims) -> nvfuser::Tensor {
+        FUSER_PERF_SCOPE("Operators.broadcast_in_dim");
+        nvfuser::FusionDefinition* fd = self.fusion_definition;
+        TORCH_CHECK(
+            output_shape.size() >= broadcast_dims.size(),
+            "broadcast_dims vector size is too big for output shape!");
+        nvfuser::Tensor output = fd->defineTensor();
+        fd->defineRecord(new nvfuser::BroadcastOpRecord(
+            {fd->recordingState(arg())},
+            {fd->recordingState(output())},
+            "ops.broadcast_in_dim",
+            output_shape,
+            broadcast_dims));
         return output;
       },
+      py::arg("arg"),
+      py::arg("output_shape"),
+      py::arg("broadcast_dims"),
       py::return_value_policy::reference);
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp
new file mode 100644
index 0000000..29040b1
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp
@@ -0,0 +1,257 @@
+#if defined(USE_CUDA)
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <torch/torch.h>
+
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_cache.h>
+#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace nvfuser;
+using namespace torch::jit::fuser::cuda;
+
+// RUN CMD: bin/test_jit --gtest_filter="NVFuserTest*PyFusionCache*"
+TEST_F(NVFuserTest, PyFusionCache_CUDA) {
+  // Create a fusion manager with a maximum of 1 Fusion
+  FusionCache* fc = FusionCache::get(1);
+
+  // You should never get a nullptr
+  ASSERT_FALSE(fc == nullptr);
+
+  // Check that cache methods all assert when presented with a null record.
+  {
+    std::unique_ptr<RecordFunctor> null_record(nullptr);
+
+    try {
+      auto bad_cache_entry_ptr = fc->lookupFusionCacheEntry(null_record.get());
+      FAIL() << "Should trigger an assert when the record is looked up!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    try {
+      fc->traverseFusionCache(null_record.get());
+      FAIL() << "Should trigger an assert when the record is looked up!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    try {
+      fc->createFusionCacheEntry(null_record.get());
+      FAIL() << "Should trigger an assert when the record is looked up!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    try {
+      auto id = fc->createFusionCacheEntry(null_record.get());
+      FAIL() << "Should trigger an assert when the record is looked up!";
+    } catch (...) {
+      SUCCEED();
+    }
+  }
+
+  // Check that cache methods act appropriately when presenting a new
+  // record to an empty cache.
+  {
+    std::unique_ptr<RecordFunctor> test_record(new TensorRecord(
+        {State(0, StateType::Tensor)}, {3}, {true}, Nvf::DataType::Float));
+
+    // Check Methods prior to adding an entry to the cache
+
+    // Cache Lookup should not succeed becase no records are in the cache
+    try {
+      auto empty_cache_entry_ptr =
+          fc->lookupFusionCacheEntry(test_record.get());
+      ASSERT_TRUE(empty_cache_entry_ptr == c10::nullopt);
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during cache lookup!" << e.what();
+    }
+
+    // Traversal of the cache should fail because there is nothing to traverse
+    try {
+      fc->traverseFusionCache(test_record.get());
+      FAIL() << "Expected the cache traversal to fail!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    // Add a cache entry and check methods
+
+    try {
+      fc->createFusionCacheEntry(test_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on Cache Entry creation!" << e.what();
+    }
+
+    try {
+      auto cache_entry_ptr = fc->lookupFusionCacheEntry(test_record.get());
+      ASSERT_FALSE(cache_entry_ptr == c10::nullopt);
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on cache lookup!" << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(test_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert during Cache Traverse!" << e.what();
+    }
+
+    // Add a terminal cache entry and check methods
+
+    std::unique_ptr<RecordFunctor> end_record(new EndRecord());
+    try {
+      auto id = fc->createFusionCacheEntry(end_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on Terminal Cache Entry creation!"
+             << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(end_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert while traversing to a Terminal Entry!"
+             << e.what();
+    }
+
+    try {
+      auto no_cache_entry_ptr = fc->lookupFusionCacheEntry(test_record.get());
+      FAIL() << "Expected an assert from a terminal entry!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    try {
+      fc->traverseFusionCache(test_record.get());
+      FAIL() << "Expected an assert from a terminal entry!";
+    } catch (...) {
+      SUCCEED();
+    }
+  }
+
+  // Setup cache for a new cache lookup
+  try {
+    fc->resetFusionCachePtr();
+    SUCCEED();
+  } catch (const std::exception& e) {
+    FAIL() << "Did not properly set cache to pointer to top of tree!"
+           << e.what();
+  }
+
+  // Check that cache methods act appropriately when presenting a new
+  // record to a cache with 1 fusion.
+  {
+    std::unique_ptr<RecordFunctor> cached_record(new TensorRecord(
+        {State(0, StateType::Tensor)}, {3}, {true}, Nvf::DataType::Float));
+    std::unique_ptr<RecordFunctor> new_record(
+        new ScalarRecord({State(1, StateType::Scalar)}, Nvf::DataType::Float));
+
+    try {
+      auto hit_cache_entry = fc->lookupFusionCacheEntry(cached_record.get());
+      ASSERT_FALSE(hit_cache_entry == c10::nullopt);
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Cache lookup unexpectedly asserted!" << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(cached_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Fusion cache traverse unexpectedly asserted!" << e.what();
+    }
+
+    try {
+      auto miss_cache_entry = fc->lookupFusionCacheEntry(new_record.get());
+      ASSERT_TRUE(miss_cache_entry == c10::nullopt);
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Cache lookup unexpectedly asserted!" << e.what();
+    }
+
+    try {
+      fc->createFusionCacheEntry(new_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on Cache Entry creation!" << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(new_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Fusion cache traverse unexpectedly asserted!" << e.what();
+    }
+
+    std::unique_ptr<RecordFunctor> end_record(new EndRecord());
+    try {
+      auto id = fc->createFusionCacheEntry(end_record.get());
+      FAIL() << "Expected the cache to assert because it is full!";
+    } catch (...) {
+      SUCCEED();
+    }
+  }
+
+  // Setup cache for a new cache lookup
+  try {
+    fc->resetFusionCachePtr();
+    SUCCEED();
+  } catch (const std::exception& e) {
+    FAIL() << "Did not properly set cache to pointer to top of tree!"
+           << e.what();
+  }
+
+  // Verify proper cache lookup up of complete fusion already cached.
+  // This tends to flush out pointer problems in the cache.
+  {
+    std::unique_ptr<RecordFunctor> test_record(new TensorRecord(
+        {State(0, StateType::Tensor)}, {3}, {true}, Nvf::DataType::Float));
+    std::unique_ptr<RecordFunctor> dummy_record(new TensorRecord(
+        {State(0, StateType::Tensor)}, {3}, {true}, Nvf::DataType::Float));
+
+    try {
+      auto cache_entry_ptr = fc->lookupFusionCacheEntry(test_record.get());
+      ASSERT_FALSE(cache_entry_ptr == c10::nullopt);
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on cache lookup!" << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(test_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert during Cache Traverse!" << e.what();
+    }
+
+    std::unique_ptr<RecordFunctor> end_record(new EndRecord());
+    try {
+      auto no_cache_entry_ptr = fc->lookupFusionCacheEntry(end_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert on cache lookup!" << e.what();
+    }
+
+    try {
+      fc->traverseFusionCache(end_record.get());
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "An unexpected assert while traversing to a Terminal Entry!"
+             << e.what();
+    }
+  }
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp
new file mode 100644
index 0000000..84aa4da
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp
@@ -0,0 +1,195 @@
+#if defined(USE_CUDA)
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <torch/torch.h>
+
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_interface.h>
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
+#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace nvfuser;
+using namespace torch::jit::fuser::cuda;
+
+// RUN CMD: bin/test_jit --gtest_filter="NVFuserTest*FusionDefinition*"
+TEST_F(NVFuserTest, FusionDefinition_CUDA) {
+  // Test that the FusionDefinition asserts on max_length == 0
+  {
+    FusionDefinition fd(nullptr, 0);
+
+    try {
+      fd.enter();
+      FAIL() << "You should trigger an assert with 0 Records allowed!";
+    } catch (...) {
+      SUCCEED();
+    }
+  }
+
+  // Test that the FusionDefinition asserts on a null FusionManager ptr
+  {
+    FusionDefinition fd(nullptr, 5);
+
+    try {
+      fd.enter();
+      FAIL() << "You should trigger an assert with a null FusionInterface!";
+    } catch (...) {
+      SUCCEED();
+    }
+  }
+
+  // Create a new FusionDefinition that is not found in the cache
+  {
+    std::unique_ptr<FusionInterface> fusion =
+        std::make_unique<FusionInterface>();
+    FusionDefinition fd(fusion.get(), 4);
+
+    try {
+      fd.enter();
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert while entering FusionDefinition context! "
+             << e.what();
+    }
+
+    auto t0 = fd.defineTensor();
+    try {
+      fd.defineRecord(new TensorRecord(
+          {fd.recordingState(t0())}, {3}, {true}, Nvf::DataType::Float));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Tensor Record creation! " << e.what();
+    }
+
+    auto s1 = fd.defineScalar();
+    try {
+      fd.defineRecord(
+          new ScalarRecord({fd.recordingState(s1())}, Nvf::DataType::Double));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Scalar Record creation! " << e.what();
+    }
+
+    auto t2 = fd.defineTensor();
+    try {
+      fd.defineRecord(
+          new OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>(
+              {fd.recordingState(t0()), fd.recordingState(s1())},
+              {fd.recordingState(t2())},
+              "ops.add",
+              static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>(
+                  Nvf::add)));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Add Record creation! " << e.what();
+    }
+
+    try {
+      fd.defineRecord(
+          new OutputRecord<Nvf::TensorView>({fd.recordingState(t2())}));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Output Record creation! " << e.what();
+    }
+
+    try {
+      fd.defineRecord(new OutputRecord<Nvf::Val>({fd.recordingState(s1())}));
+      FAIL() << "Expected an assert for too many records!";
+    } catch (...) {
+      SUCCEED();
+    }
+
+    try {
+      fd.exit();
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during creation of a new Fusion! "
+             << e.what();
+    }
+  }
+
+  // Look up a FusionDefinition with a defined Fusion
+  {
+    std::unique_ptr<FusionInterface> fusion =
+        std::make_unique<FusionInterface>(0);
+    FusionDefinition fd(fusion.get(), 1);
+
+    try {
+      fd.enter();
+      FAIL() << "You should trigger an assert with a defined FusionInterface!";
+    } catch (const std::exception& e) {
+      SUCCEED();
+    }
+  }
+
+  // Look up a FusionDefinition completely in the cache
+  {
+    std::unique_ptr<FusionInterface> fusion =
+        std::make_unique<FusionInterface>();
+    FusionDefinition fd(fusion.get(), 4);
+
+    try {
+      fd.enter();
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert while entering FusionDefinition context! "
+             << e.what();
+    }
+
+    auto t0 = fd.defineTensor();
+    try {
+      fd.defineRecord(new TensorRecord(
+          {fd.recordingState(t0())}, {3}, {true}, Nvf::DataType::Float));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Tensor Record creation! " << e.what();
+    }
+
+    auto s1 = fd.defineScalar();
+    try {
+      fd.defineRecord(
+          new ScalarRecord({fd.recordingState(s1())}, Nvf::DataType::Double));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Scalar Record creation! " << e.what();
+    }
+
+    auto t2 = fd.defineTensor();
+    try {
+      fd.defineRecord(
+          new OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>(
+              {fd.recordingState(t0()), fd.recordingState(s1())},
+              {fd.recordingState(t2())},
+              "ops.add",
+              static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>(
+                  Nvf::add)));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Add Record creation! " << e.what();
+    }
+
+    try {
+      fd.defineRecord(
+          new OutputRecord<Nvf::TensorView>({fd.recordingState(t2())}));
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during Output Record creation! " << e.what();
+    }
+
+    try {
+      fd.exit();
+      SUCCEED();
+    } catch (const std::exception& e) {
+      FAIL() << "Unexpected assert during creation of a new Fusion! "
+             << e.what();
+    }
+  }
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp
new file mode 100644
index 0000000..4778515
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp
@@ -0,0 +1,135 @@
+#if defined(USE_CUDA)
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <torch/torch.h>
+
+#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h>
+#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace nvfuser;
+using namespace torch::jit::fuser::cuda;
+
+// RUN CMD: bin/test_jit --gtest_filter="NVFuserTest*RecordFunctorEquality*"
+TEST_F(NVFuserTest, RecordFunctorEquality_CUDA) {
+  // Getting the std::function matching correct is error prone so providing
+  // checks for OpRecord, CastOp, and ReductionOp that employ std::function
+  // matching.
+
+  // OpRecord Equality Check
+  {
+    auto t0 = nvfuser::State(0, StateType::Tensor);
+    auto s1 = nvfuser::State(1, StateType::Scalar);
+    auto out = nvfuser::State(2, StateType::Tensor);
+    std::unique_ptr<RecordFunctor> test_record1(
+        new OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>(
+            {t0, s1},
+            {out},
+            "ops.mul",
+            static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>(
+                Nvf::mul)));
+    std::unique_ptr<RecordFunctor> test_record2(
+        new OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>(
+            {t0, s1},
+            {out},
+            "ops.mul",
+            static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>(
+                Nvf::mul)));
+    std::unique_ptr<RecordFunctor> test_record3(
+        new OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>(
+            {t0, s1},
+            {out},
+            "ops.mul",
+            static_cast<Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*)>(
+                Nvf::mul)));
+
+    EXPECT_TRUE(*test_record1 == *test_record2);
+    EXPECT_TRUE(*test_record1 == *test_record3);
+    EXPECT_TRUE(*test_record2 == *test_record3);
+  }
+
+  // CastOpRecord Equality Check
+  {
+    auto t0 = nvfuser::State(0, StateType::Tensor);
+    auto out = nvfuser::State(1, StateType::Tensor);
+    std::unique_ptr<RecordFunctor> test_record1(
+        new CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>(
+            {t0},
+            {out},
+            "ops.cast",
+            static_cast<Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>(
+                Nvf::castOp),
+            Nvf::DataType::Half));
+    std::unique_ptr<RecordFunctor> test_record2(
+        new CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>(
+            {t0},
+            {out},
+            "ops.cast",
+            static_cast<Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>(
+                Nvf::castOp),
+            Nvf::DataType::Half));
+    std::unique_ptr<RecordFunctor> test_record3(
+        new CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>(
+            {t0},
+            {out},
+            "ops.cast",
+            static_cast<Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>(
+                Nvf::castOp),
+            Nvf::DataType::Half));
+
+    EXPECT_TRUE(*test_record1 == *test_record2);
+    EXPECT_TRUE(*test_record1 == *test_record3);
+    EXPECT_TRUE(*test_record2 == *test_record3);
+  }
+
+  // ReductionOpRecord Equality Check
+  {
+    auto t0 = nvfuser::State(0, StateType::Tensor);
+    auto out = nvfuser::State(1, StateType::Tensor);
+    std::unique_ptr<RecordFunctor> test_record1(new ReductionOpRecord(
+        {t0},
+        {out},
+        "ops.sum",
+        static_cast<
+            Nvf::
+                TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>(
+            Nvf::sum),
+        {0},
+        false,
+        Nvf::DataType::Float));
+    std::unique_ptr<RecordFunctor> test_record2(new ReductionOpRecord(
+        {t0},
+        {out},
+        "ops.sum",
+        static_cast<
+            Nvf::
+                TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>(
+            Nvf::sum),
+        {0},
+        false,
+        Nvf::DataType::Float));
+    std::unique_ptr<RecordFunctor> test_record3(new ReductionOpRecord(
+        {t0},
+        {out},
+        "ops.sum",
+        static_cast<
+            Nvf::
+                TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>(
+            Nvf::sum),
+        {0},
+        false,
+        Nvf::DataType::Float));
+
+    EXPECT_TRUE(*test_record1 == *test_record2);
+    EXPECT_TRUE(*test_record1 == *test_record3);
+    EXPECT_TRUE(*test_record2 == *test_record3);
+  }
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp
index fcb8187..95c5424 100644
--- a/torch/csrc/jit/codegen/cuda/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/utils.cpp
@@ -18,6 +18,7 @@
   std::unordered_map<DebugDumpOption, bool> options_map = {
       {DebugDumpOption::FusionIr, false},
       {DebugDumpOption::FusionIrMath, false},
+      {DebugDumpOption::FusionIrPresched, false},
       {DebugDumpOption::KernelIr, false},
       {DebugDumpOption::ComputeAtMap, false},
       {DebugDumpOption::CudaKernel, false},
@@ -37,6 +38,8 @@
       {DebugDumpOption::ParallelDimensions, false},
       {DebugDumpOption::Halo, false},
       {DebugDumpOption::PerfDebugVerbose, false},
+      {DebugDumpOption::PythonDefinition, false},
+      {DebugDumpOption::PythonFrontendDebug, false},
       {DebugDumpOption::TransformPropagator, false},
       {DebugDumpOption::InlinePropagator, false}};
 
@@ -49,6 +52,8 @@
         options_map[DebugDumpOption::FusionIr] = true;
       } else if (token == "fusion_ir_math") {
         options_map[DebugDumpOption::FusionIrMath] = true;
+      } else if (token == "fusion_ir_presched") {
+        options_map[DebugDumpOption::FusionIrPresched] = true;
       } else if (token == "kernel_ir") {
         options_map[DebugDumpOption::KernelIr] = true;
       } else if (token == "ca_map") {
@@ -87,6 +92,10 @@
         options_map[DebugDumpOption::Halo] = true;
       } else if (token == "perf_debug_verbose") {
         options_map[DebugDumpOption::PerfDebugVerbose] = true;
+      } else if (token == "python_definition") {
+        options_map[DebugDumpOption::PythonDefinition] = true;
+      } else if (token == "python_frontend_debug") {
+        options_map[DebugDumpOption::PythonFrontendDebug] = true;
       } else if (token == "transform_propagator") {
         options_map[DebugDumpOption::TransformPropagator] = true;
       } else if (token == "inline_propagator") {
@@ -97,11 +106,12 @@
             "Invalid debug dump option: '",
             token,
             "'\nAvailable options:\n",
-            "\tfusion_ir, fusion_ir_math, kernel_ir, ca_map, cuda_kernel, cuda_full,\n",
-            "\tcuda_to_file, debug_info, launch_param, segmented_fusion, fusion_args,\n",
-            "\tkernel_args, dump_eff_bandwidth, draw_segmented_fusion,\n",
-            "\tscheduler_params, parallel_dimensions, buffer_reuse_verbose,\n",
-            "\tptxas_verbose, halo, segmenter_logging, perf_debug_verbose\n",
+            "\tfusion_ir, fusion_ir_math, fusion_ir_presched, kernel_ir, ca_map,\n",
+            "\tcuda_kernel, cuda_full, cuda_to_file, debug_info, launch_param,\n",
+            "\tsegmented_fusion, fusion_args, kernel_args, dump_eff_bandwidth,\n",
+            "\tdraw_segmented_fusion, scheduler_params, parallel_dimensions,\n",
+            "\tbuffer_reuse_verbose, ptxas_verbose, halo, segmenter_logging,\n",
+            "\tperf_debug_verbose, python_definition, python_frontend_debug,\n",
             "\ttransform_propagator, inline_propagator\n");
       }
       options_view = (end_pos != c10::string_view::npos)
diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h
index 77f5ab8..f0a715b 100644
--- a/torch/csrc/jit/codegen/cuda/utils.h
+++ b/torch/csrc/jit/codegen/cuda/utils.h
@@ -24,6 +24,7 @@
 enum class DebugDumpOption {
   FusionIr, //!< Dump the Fusion IR before lowering
   FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR
+  FusionIrPresched, //!< Dump the Fusion IR before it is scheduled.
   KernelIr, //!< Dump the compiler Kernel IR
   ComputeAtMap, //!< Dump the computeAt map
   CudaKernel, //!< Dump the generated CUDA C++ kernel code
@@ -46,10 +47,12 @@
   Halo, //! Halo information of tensors
   PerfDebugVerbose, //! When running kernels, print verbose information
                     //! associated with what's running
+  PythonDefinition, //! Python Frontend Fusion Definition.
+  PythonFrontendDebug, //! Python Frontend debug information.
   TransformPropagator, //! When running TransformPropagator, print propagation
                        //! path and replay result
-  InlinePropagator, //! When running InlinePropagator, print propagation
-                    //! path and inlining result
+  InlinePropagator //! When running InlinePropagator, print propagation
+                   //! path and inlining result
 };
 
 TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);