Revert "Extend Inductor to support the third-party backend (#100706)" (#106652)

This reverts commit 05bd24bb3548105776cf73226927cbd0ed575c55.

It caused a compilation time regression on torchbench, huggingface, and dynamic models.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106652
Approved by: https://github.com/davidberard98, https://github.com/voznesenskym
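
For context, the reverted commit added a runtime registration hook (register_backend_for_device in torch._inductor.codegen.common) plus a BaseScheduling interface so that out-of-tree backends could plug their own Scheduling and WrapperCodeGen classes into Inductor. Below is a minimal sketch of that flow, condensed from the test being deleted in this patch; it only runs on a tree that still contains #100706, and the "extension_device" name and the stub classes are illustrative placeholders, not a definitive implementation:

    # Sketch of the (now reverted) out-of-tree backend registration flow.
    # Requires the pre-revert tree; the stub classes mirror the deleted
    # test/inductor/extension_backends/extension_codegen_backend.py helpers.
    import torch
    from torch._inductor.codegen.common import (
        get_scheduling_for_device,
        get_wrapper_codegen_for_device,
        register_backend_for_device,
    )
    from torch._inductor.codegen.wrapper import WrapperCodeGen
    from torch._inductor.scheduler import BaseScheduling

    class ExtensionWrapperCodegen(WrapperCodeGen):
        pass

    class ExtensionScheduling(BaseScheduling):
        def __init__(self, scheduler):
            self.scheduler = scheduler
        # codegen_nodes/group_fn/flush etc. would delegate to an existing
        # scheduling (e.g. CppScheduling), as in the deleted test helper.

    # Give the PrivateUse1 device a name and map it to the custom codegen classes.
    torch.utils.rename_privateuse1_backend("extension_device")
    register_backend_for_device(
        "extension_device", ExtensionScheduling, ExtensionWrapperCodegen
    )
    assert get_scheduling_for_device("extension_device") is ExtensionScheduling
    assert get_wrapper_codegen_for_device("extension_device") is ExtensionWrapperCodegen

With this revert applied, register_backend_for_device and BaseScheduling no longer exist, and Inductor again hardcodes CppScheduling for "cpu" and TritonScheduling for "cuda" in Scheduler.create_backend.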
diff --git a/test/inductor/extension_backends/extension_codegen_backend.py b/test/inductor/extension_backends/extension_codegen_backend.py
deleted file mode 100644
index 634a8e2..0000000
--- a/test/inductor/extension_backends/extension_codegen_backend.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from torch._inductor.codegen import cpp, wrapper
-from torch._inductor.scheduler import BaseScheduling
-from torch._inductor.virtualized import V
-
-
-class ExtensionWrapperCodegen(wrapper.WrapperCodeGen):
-    def __init__(self):
-        super().__init__()
-
-
-class ExtensionScheduling(BaseScheduling):
-    def __init__(self, scheduler):
-        self.scheduler = scheduler
-        self._scheduling = cpp.CppScheduling(scheduler)
-
-    def can_fuse_vertical(self, node1, node2):
-        return True
-
-    def can_fuse_horizontal(self, node1, node2):
-        return True
-
-    def group_fn(self, sizes):
-        return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
-
-    def codegen_template(self, template_node, epilogue_nodes):
-        pass
-
-    def codegen_nodes(self, nodes):
-        self._scheduling.codegen_nodes(nodes)
-
-    def codegen_sync(self):
-        pass
-
-    def flush(self):
-        self._scheduling.flush()
diff --git a/test/inductor/extension_backends/extension_device.cpp b/test/inductor/extension_backends/extension_device.cpp
deleted file mode 100644
index c130add..0000000
--- a/test/inductor/extension_backends/extension_device.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-#include <c10/core/impl/alloc_cpu.h>
-#include <c10/core/Allocator.h>
-
-#include <torch/csrc/Device.h>
-#include <c10/core/impl/DeviceGuardImplInterface.h>
-#include <c10/macros/Macros.h>
-#include <torch/extension.h>
-
-#include <ATen/native/cpu/Loops.h>
-#include <ATen/native/DispatchStub.h>
-#include <ATen/native/Resize.h>
-#include <ATen/EmptyTensor.h>
-#include <ATen/core/GeneratorForPrivateuseone.h>
-
-static uint64_t op_counter = 0;
-static uint64_t last_saved_value = 0;
-
-// register guard
-namespace at {
-namespace detail {
-
-C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
-
-}} // namespace at::detail
-
-// basic dummy add function
-at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
-  op_counter += 1;
-  // Since this custom device is just for testing, not bothering to implement kernels.
-  return at::empty(self.sizes(), self.options());
-}
-
-// basic dummy mul function
-at::Tensor custom_mul_Tensor(const at::Tensor & self, const at::Tensor & other) {
-  op_counter += 1;
-  // Since this custom device is just for testing, not bothering to implement kernels.
-  return at::empty(self.sizes(), self.options());
-}
-
-// basic dummy eq function: Only support CPU
-at::Tensor custom_to_device(
-    const at::Tensor & self,
-    at::Device device,
-    at::ScalarType dtype,
-    bool non_blocking,
-    bool copy,
-    c10::optional<at::MemoryFormat> memory_format) {
-  TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
-  TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
-  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
-  TORCH_CHECK(self.scalar_type() == dtype);
-  TORCH_CHECK(self.is_contiguous());
-
-  op_counter += 1;
-  if (device != at::DeviceType::CPU) {
-    return at::empty(self.sizes(), self.options());
-  }
-
-  auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format);
-  memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes());
-  // Since this custom device is just for testing, not bothering to implement kernels.
-  return out;
-}
-
-
-// A dummy allocator for our custom device, that secretly uses the CPU
-struct DummyCustomAllocator final : at::Allocator {
-  DummyCustomAllocator() = default;
-  at::DataPtr allocate(size_t nbytes) const override {
-    void* data = c10::alloc_cpu(nbytes);
-    return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
-  }
-
-  static void ReportAndDelete(void* ptr) {
-    if (!ptr) {
-      return;
-    }
-    c10::free_cpu(ptr);
-  }
-
-  at::DeleterFnPtr raw_deleter() const override {
-    return &ReportAndDelete;
-  }
-};
-
-// Register our dummy allocator
-static DummyCustomAllocator global_custom_alloc;
-REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
-
-at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
-  TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows dummy device.");
-  TORCH_CHECK(self.is_contiguous());
-  TORCH_CHECK(self.scalar_type() == c10::ScalarType::Float);
-
-  op_counter += 1;
-  auto _data = static_cast<float*>(self.mutable_data_ptr());
-  for (size_t idx = 0; idx < self.numel(); idx++) {
-    _data[idx] = value.toFloat();
-  }
-
-  return self;
-}
-
-// basic dummy copy_() function, so we can copy from the custom device to/from CPU
-at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
-  TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
-  TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device.");
-
-  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
-  TORCH_CHECK(self.sizes() == dst.sizes());
-  TORCH_CHECK(self.scalar_type() == dst.scalar_type());
-  TORCH_CHECK(self.is_contiguous() && dst.is_contiguous());
-
-  op_counter += 1;
-  std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes());
-  return dst;
-}
-
-at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt, c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  op_counter += 1;
-
-  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
-  auto dtype = c10::dtype_or_default(dtype_opt);
-  return  at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
-}
-
-// This macro does the heavy lifting.
-// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
-// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
-// Later in this file, we map a custom device to the PrivateUse1 device type,
-// which allows user code that puts a tensor on your custom_device to eventually get plumbed
-// into the kernels registered here.
-//
-// This macro registers your kernels to the PyTorch Dispatcher.
-// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
-TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
-  m.impl("add.Tensor", &custom_add_Tensor);
-  m.impl("mul.Tensor", &custom_mul_Tensor);
-  m.impl("to.Device", &custom_to_device);
-  m.impl("fill_.Scalar", &custom_fill__scalar);
-  m.impl("_copy_from", &custom__copy_from);
-  m.impl("empty_strided", &custom_empty_strided);
-}
-
-// This basic implementation doesn't bother dealing with different device indices
-// (e.g. custom_device:0 vs. custom_device:1).
-// We could do that by letting the user pass in a device index in our exposed device function.
-// Note that if you do that, you'll also need to register a device guard to core.
-// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
-c10::Device get_custom_device() {
-  return c10::Device(c10::DeviceType::PrivateUse1, 0);
-}
-
-bool custom_op_called() {
-  bool called = false;
-  if (op_counter > last_saved_value) {
-    called = true;
-    last_saved_value = op_counter;
-  }
-  return called;
-}
-
-class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
-public:
-  // Constructors
-  PrivateGeneratorImpl(c10::DeviceIndex device_index) {
-    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
-    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
-  }
-  ~PrivateGeneratorImpl() override = default;
-};
-
-// this is used to register generator
-at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
-  return at::make_generator<PrivateGeneratorImpl>(device_index);
-}
-
-void register_generator() {
-  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
-}
-
-// Here, we're exposing a custom device object that corresponds to our custom backend.
-// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
-// that's implemented in C++.
-// The implementation in this file maps directly to the `PrivateUse1` device type.
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("custom_device", &get_custom_device, "get custom device object");
-    m.def("custom_op_called", &custom_op_called, "check if our custom function was called");
-    m.def("register_generator", &register_generator, "register generator for custom device");
-}
diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py
deleted file mode 100644
index 139ddbd..0000000
--- a/test/inductor/test_extension_backend.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Owner(s): ["module: inductor"]
-import os
-import shutil
-import sys
-import unittest
-
-import torch
-import torch._dynamo
-import torch.utils.cpp_extension
-from torch._C import FileCheck
-from torch._inductor import metrics
-from torch._inductor.codegen.common import (
-    get_scheduling_for_device,
-    get_wrapper_codegen_for_device,
-    register_backend_for_device,
-)
-from torch.testing._internal.common_utils import IS_MACOS
-
-try:
-    from .extension_backends.extension_codegen_backend import (
-        ExtensionScheduling,
-        ExtensionWrapperCodegen,
-    )
-except ImportError:
-    from extension_backends.extension_codegen_backend import (
-        ExtensionScheduling,
-        ExtensionWrapperCodegen,
-    )
-
-try:
-    try:
-        from . import test_torchinductor
-    except ImportError:
-        import test_torchinductor
-except unittest.SkipTest:
-    if __name__ == "__main__":
-        sys.exit(0)
-    raise
-
-
-run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
-TestCase = test_torchinductor.TestCase
-
-
-def remove_build_path():
-    if sys.platform == "win32":
-        # Not wiping extensions build folder because Windows
-        return
-    default_build_root = torch.utils.cpp_extension.get_default_build_root()
-    if os.path.exists(default_build_root):
-        shutil.rmtree(default_build_root, ignore_errors=True)
-
-
-class ExtensionBackendTests(TestCase):
-    module = None
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Build Extension
-        remove_build_path()
-        source_file_path = os.path.dirname(os.path.abspath(__file__))
-        source_file = os.path.join(
-            source_file_path, "extension_backends/extension_device.cpp"
-        )
-        cls.module = torch.utils.cpp_extension.load(
-            name="extension_device",
-            sources=[
-                str(source_file),
-            ],
-            extra_cflags=["-g"],
-            verbose=True,
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._stack.close()
-        super().tearDownClass()
-
-        remove_build_path()
-
-    def setUp(self):
-        torch._dynamo.reset()
-        super().setUp()
-
-        # cpp extensions use relative paths. Those paths are relative to
-        # this file, so we'll change the working directory temporarily
-        self.old_working_dir = os.getcwd()
-        os.chdir(os.path.dirname(os.path.abspath(__file__)))
-        assert self.module is not None
-
-    def tearDown(self):
-        super().tearDown()
-        torch._dynamo.reset()
-
-        # return the working directory (see setUp)
-        os.chdir(self.old_working_dir)
-
-    def test_open_device_registration(self):
-        torch.utils.rename_privateuse1_backend("extension_device")
-
-        register_backend_for_device(
-            "extension_device", ExtensionScheduling, ExtensionWrapperCodegen
-        )
-        self.assertTrue(
-            get_scheduling_for_device("extension_device") == ExtensionScheduling
-        )
-        self.assertTrue(
-            get_wrapper_codegen_for_device("extension_device")
-            == ExtensionWrapperCodegen
-        )
-
-        self.assertFalse(self.module.custom_op_called())
-        device = self.module.custom_device()
-        x = torch.empty(2, 16).to(device=device).fill_(1)
-        self.assertTrue(self.module.custom_op_called())
-        y = torch.empty(2, 16).to(device=device).fill_(2)
-        z = torch.empty(2, 16).to(device=device).fill_(3)
-        ref = torch.empty(2, 16).fill_(5)
-
-        self.assertTrue(x.device == device)
-        self.assertTrue(y.device == device)
-        self.assertTrue(z.device == device)
-
-        def fn(a, b, c):
-            return a * b + c
-
-        metrics.reset()
-        opt_fn = torch.compile()(fn)
-        code = run_and_get_cpp_code(opt_fn, x, y, z)
-        FileCheck().check("void kernel").check("loadu").check("extension_device").run(
-            code
-        )
-        opt_fn(x, y, z)
-        res = opt_fn(x, y, z)
-        self.assertEqual(ref, res.to(device="cpu"))
-
-
-if __name__ == "__main__":
-    from torch._dynamo.test_case import run_tests
-    from torch.testing._internal.inductor_utils import HAS_CPU
-
-    if HAS_CPU and not IS_MACOS:
-        run_tests(needs="filelock")
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 3f59c59..a447884 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -38,46 +38,6 @@
 TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
 SizeArg = namedtuple("SizeArg", ["name", "expr"])
 
-DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"])
-device_codegens: typing.Dict[str, DeviceCodegen] = {}
-
-
-# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
-# For any new backend looking to integrate with Inductor, customization of these two main
-# parts are necessary to generate its specific code.
-#
-# Kernel code generation is determined by different Scheduling. Consequently, a new
-# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
-# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
-#
-# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
-# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
-# and override specific member functions to create backend-specific Python wrapper code.
-#
-# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
-# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
-# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
-# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
-# register_backend_for_device, to equip a new backend at runtime.
-#
-# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
-# This backend can be used as a reference:
-# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
-def register_backend_for_device(
-    device: str, device_scheduling: type, device_wrapper_codegen: type
-):
-    device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen)
-
-
-def get_scheduling_for_device(device: str):
-    return device_codegens[device].scheduling if device in device_codegens else None
-
-
-def get_wrapper_codegen_for_device(device: str):
-    return (
-        device_codegens[device].wrapper_codegen if device in device_codegens else None
-    )
-
 
 def index_prevent_reordering(index: typing.List[sympy.Expr], index_vars, sizes):
     from ..ir import FlexibleLayout
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 9b34ea3..62a9e5e 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -23,7 +23,7 @@
 from .. import codecache, config, ir, metrics
 from ..codegen.wrapper import WrapperCodeGen
 from ..optimize_indexing import range_expressable_in_32_bits
-from ..scheduler import BaseScheduling, SchedulerNode
+from ..scheduler import SchedulerNode
 from ..utils import (
     cache_on_self,
     get_fused_kernel_name,
@@ -2699,7 +2699,7 @@
         self.codegen_loops_impl(self.loop_nest, code, worksharing)
 
 
-class CppScheduling(BaseScheduling):
+class CppScheduling:
     def __init__(self, scheduler):
         self.scheduler = scheduler
         self.get_kernel_group()
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0664cf6..d9ed9c4 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -22,7 +22,6 @@
 from ..dependencies import MemoryDep, StarDep
 from ..ir import ReductionHint
 from ..optimize_indexing import indexing_dtype_strength_reduction
-from ..scheduler import BaseScheduling
 from ..triton_heuristics import AutotuneHint
 from ..utils import (
     DeferredLineBase,
@@ -38,6 +37,7 @@
 )
 from ..virtualized import ops, V
 from ..wrapper_benchmark import get_kernel_category_by_source_code
+
 from .common import (
     CSEVariable,
     DeferredLine,
@@ -1768,7 +1768,6 @@
         triton_meta = {
             "signature": dict(enumerate(map(signature_of, signature))),
             "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
             "mutated_arg_names": mutated_args,
             "autotune_hints": set(self.autotune_hints),
@@ -1992,7 +1991,7 @@
         return TritonCSEVariable(*args, **kwargs)
 
 
-class TritonScheduling(BaseScheduling):
+class TritonScheduling:
     def __init__(self, scheduler):
         self.scheduler = scheduler
 
diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py
index c9c4049..f441bcd 100644
--- a/torch/_inductor/codegen/triton_foreach.py
+++ b/torch/_inductor/codegen/triton_foreach.py
@@ -88,7 +88,6 @@
         triton_meta = {
             "signature": dict(enumerate(map(signature_of, signature))),
             "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 7277a54..81b5f58 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -410,13 +410,13 @@
     def next_kernel_suffix(self):
         return f"{next(self._names_iter)}"
 
-    def codegen_device_guard_enter(self, device_idx):
+    def codegen_cuda_device_guard_enter(self, device_idx):
         self.writeline(
             EnterCudaDeviceContextManagerLine(device_idx, self.first_device_guard)
         )
         self.first_device_guard = False
 
-    def codegen_device_guard_exit(self):
+    def codegen_cuda_device_guard_exit(self):
         self.writeline(ExitCudaDeviceContextManagerLine())
 
     def generate_return(self, output_refs):
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 51fb2fa..8feb008 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -26,11 +26,6 @@
 from torch.utils._mode_utils import no_dispatch
 
 from . import config, ir, metrics
-from .codegen.common import (
-    get_scheduling_for_device,
-    get_wrapper_codegen_for_device,
-    register_backend_for_device,
-)
 from .codegen.wrapper import CppWrapperCodeGen, CudaWrapperCodeGen, WrapperCodeGen
 from .exc import (
     LoweringException,
@@ -150,17 +145,6 @@
         stride = [sympy.Integer(i) for i in ex.stride()]
         return size, stride
 
-    def init_backend_registration(self):
-        if get_scheduling_for_device("cpu") is None:
-            from .codegen.cpp import CppScheduling
-
-            register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
-
-        if get_scheduling_for_device("cuda") is None:
-            from .codegen.triton import TritonScheduling
-
-            register_backend_for_device("cuda", TritonScheduling, WrapperCodeGen)
-
     def __init__(
         self,
         gm: torch.fx.GraphModule,
@@ -230,7 +214,6 @@
         )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
         # Used if lowering encounters cases where cudagraphs are not supported
         self.disable_cudagraphs = False
-        self.init_backend_registration()
 
     @staticmethod
     def decide_layout_opt(gm) -> bool:
@@ -884,20 +867,7 @@
                 )
                 return
 
-        device_types = self.device_types.copy()
-        # In terms of some operations that don't have input tensors, we need to
-        # check the device of the buffers.
-        for buffer in self.buffers:
-            device_types.add(buffer.get_device().type)
-        device_types.discard("cpu")
-        # TODO(Eikan): Only support mixing cpu and other device now.
-        assert len(device_types) <= 1, "Does not support mixing {}".format(
-            "+".join(device_types)
-        )
-        only_cpu = len(device_types) == 0
-        device_type = "cpu" if only_cpu else device_types.pop()
-        wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type)
-        self.wrapper_code = wrapper_code_gen_cls()
+        self.wrapper_code = WrapperCodeGen()
 
     def codegen(self):
         from .scheduler import Scheduler
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index f9d43c9..aea91be 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -14,7 +14,6 @@
 from torch._dynamo.utils import dynamo_timed
 
 from . import config, dependencies, ir, metrics
-from .codegen.common import get_scheduling_for_device
 from .dependencies import StarDep, WeakDep
 from .sizevars import SimplifyIndexing
 from .utils import cache_on_self, cmp, free_symbol_has, has_triton
@@ -1471,23 +1470,27 @@
         V.graph.device_types.add(device.type)
         V.graph.add_device_idx(device.index)
 
-        device_scheduling = get_scheduling_for_device(device.type)
-        if device_scheduling is None:
+        if device.type == "cpu":
+            from .codegen.cpp import CppScheduling
+
+            return CppScheduling(self)
+        elif device.type == "cuda":
+            if not has_triton():
+                device_props = torch.cuda.get_device_properties(device)
+                if device_props.major < 7:
+                    raise RuntimeError(
+                        f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
+                    )
+            from .codegen.triton import TritonScheduling
+
+            return TritonScheduling(self)
+        else:
             raise RuntimeError(f"Unsupported device type: {device.type}")
 
-        if device.type == "cuda" and not has_triton():
-            device_props = torch.cuda.get_device_properties(device)
-            if device_props.major < 7:
-                raise RuntimeError(
-                    f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
-                )
-            else:
-                raise RuntimeError(
-                    "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
-                )
-
-        return device_scheduling(self)
-
     def get_backend(self, device: torch.device):
         if device not in self.backends:
             self.backends[device] = self.create_backend(device)
@@ -1521,11 +1524,13 @@
                 if device != self.current_device:
                     if device.type == "cuda":
                         if self.current_device and self.current_device.type == "cuda":
-                            V.graph.wrapper_code.codegen_device_guard_exit()
+                            V.graph.wrapper_code.codegen_cuda_device_guard_exit()
                         assert device.index is not None, "device should have an index"
-                        V.graph.wrapper_code.codegen_device_guard_enter(device.index)
+                        V.graph.wrapper_code.codegen_cuda_device_guard_enter(
+                            device.index
+                        )
                     elif self.current_device and self.current_device.type == "cuda":
-                        V.graph.wrapper_code.codegen_device_guard_exit()
+                        V.graph.wrapper_code.codegen_cuda_device_guard_exit()
                     self.current_device = device
 
             self.buffer_names_to_free.update(node.last_usage)
@@ -1549,52 +1554,3 @@
             self.available_buffer_names.update(node.get_names())
 
         self.flush()
-
-
-class BaseScheduling:
-    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
-        """
-        Check whether node1 and node2 can be vertically fused or not.
-        """
-        raise NotImplementedError()
-
-    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
-        """
-        Check whether node1 and node2 can be horizontally fused or not.
-        """
-        raise NotImplementedError()
-
-    def group_fn(self, sizes):
-        """
-        Process the iteration sizes in case a transformation needs to be applied.
-        """
-        raise NotImplementedError()
-
-    def codegen_template(
-        self, template_node: BaseSchedulerNode, epilogue_nodes: List[BaseSchedulerNode]
-    ):
-        """
-        Given a template node, generate a kernel.
-
-        This function is only available for triton now. If the third-party backend behaves as a sub-class
-        of TritonScheduling, it can override it or reuse it.
-        """
-        raise NotImplementedError()
-
-    def codegen_nodes(self, nodes: List[BaseSchedulerNode]):
-        """
-        Generate a kernel given a list of pre-fused nodes.
-        """
-        raise NotImplementedError()
-
-    def codegen_sync(self):
-        """
-        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
-        """
-        raise NotImplementedError()
-
-    def flush(self):
-        """
-        Flush the generated kernel and python wrapper code to the source code file.
-        """
-        raise NotImplementedError()
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index f7ca06b..3b128e5 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -116,7 +116,6 @@
         triton_meta = {
             "signature": dict(enumerate(map(signature_of, signature))),
             "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]