AutoHeuristic: Support ranking/pruning choices (#131705)

This PR adds support in train_decision if one wants to learn a heuristic for ranking. The main idea is that the user has to provide a number of choices the heuristic should return. I added a way to prune the learned decision tree such that it always returns the number of choices provided by the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131705
Approved by: https://github.com/eellison
diff --git a/torchgen/_autoheuristic/ah_tree.py b/torchgen/_autoheuristic/ah_tree.py
new file mode 100644
index 0000000..3991ffc
--- /dev/null
+++ b/torchgen/_autoheuristic/ah_tree.py
@@ -0,0 +1,262 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+from sklearn.tree import _tree  # type: ignore[import-untyped]
+
+
+class DecisionTreeNode:
+    def __init__(
+        self,
+        feature: Optional[str] = None,
+        threshold: Optional[float] = None,
+        left: Optional["DecisionTreeNode"] = None,
+        right: Optional["DecisionTreeNode"] = None,
+        class_probs: Any = None,
+        num_samples: int = 0,
+        node_id: int = 0,
+    ) -> None:
+        self.feature = feature
+        self.threshold = threshold
+        self.left = left
+        self.right = right
+        self.class_probs = class_probs
+        self.num_samples = num_samples
+        self.id = node_id
+
+    def is_leaf(self) -> bool:
+        return self.left is None or self.right is None
+
+
+class DecisionTree:
+    """
+    Custom decision tree implementation that mimics some of the sklearn API.
+    The purpose of this class it to be able to perform transformations, such as custom pruning, which
+    does not seem to be easy with sklearn.
+    """
+
+    def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None:
+        self.feature_names = feature_names
+        self.root = self._convert_sklearn_tree(sklearn_tree.tree_)
+        self.classes_: List[str] = sklearn_tree.classes_
+
+    def _convert_sklearn_tree(
+        self, sklearn_tree: Any, node_id: int = 0
+    ) -> DecisionTreeNode:
+        class_probs = sklearn_tree.value[node_id][0]
+        num_samples = sklearn_tree.n_node_samples[node_id]
+        if sklearn_tree.feature[node_id] != _tree.TREE_UNDEFINED:
+            feature_index = sklearn_tree.feature[node_id]
+            feature = self.feature_names[feature_index]
+            left = self._convert_sklearn_tree(
+                sklearn_tree, sklearn_tree.children_left[node_id]
+            )
+            right = self._convert_sklearn_tree(
+                sklearn_tree, sklearn_tree.children_right[node_id]
+            )
+            return DecisionTreeNode(
+                feature=feature,
+                threshold=sklearn_tree.threshold[node_id],
+                left=left,
+                right=right,
+                class_probs=class_probs,
+                num_samples=num_samples,
+                node_id=node_id,
+            )
+        else:
+            return DecisionTreeNode(
+                class_probs=class_probs, num_samples=num_samples, node_id=node_id
+            )
+
+    def prune(self, df: Any, target_col: str, k: int) -> None:
+        self.root = self._prune_tree(self.root, df, target_col, k)
+
+    def _prune_tree(
+        self, node: DecisionTreeNode, df: Any, target_col: str, k: int
+    ) -> DecisionTreeNode:
+        if node.is_leaf():
+            return node
+
+        left_df = df[df[node.feature] <= node.threshold]
+        right_df = df[df[node.feature] > node.threshold]
+
+        # number of unique classes in the left and right subtrees
+        left_counts = left_df[target_col].nunique()
+        right_counts = right_df[target_col].nunique()
+
+        # for ranking, we want to ensure that we return at least k classes, so if we have less than k classes in the
+        # left or right subtree, we remove the split and make this node a leaf node
+        if left_counts < k or right_counts < k:
+            return DecisionTreeNode(class_probs=node.class_probs)
+
+        assert node.left is not None, "expected left child to exist"
+        node.left = self._prune_tree(node.left, left_df, target_col, k)
+        assert node.right is not None, "expected right child to exist"
+        node.right = self._prune_tree(node.right, right_df, target_col, k)
+
+        return node
+
+    def to_dot(self) -> str:
+        dot = "digraph DecisionTree {\n"
+        dot += '    node [fontname="helvetica"];\n'
+        dot += '    edge [fontname="helvetica"];\n'
+        dot += self._node_to_dot(self.root)
+        dot += "}"
+        return dot
+
+    def _node_to_dot(
+        self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = ""
+    ) -> str:
+        if node is None:
+            return ""
+
+        node_id = id(node)
+
+        # Format class_probs array with line breaks
+        class_probs_str = self._format_class_probs_array(
+            node.class_probs, node.num_samples
+        )
+
+        if node.is_leaf():
+            label = class_probs_str
+            shape = "box"
+        else:
+            feature_name = f"{node.feature}"
+            label = f"{feature_name} <= {node.threshold:.2f}\\n{class_probs_str}"
+            shape = "oval"
+
+        dot = f'    {node_id} [label="{label}", shape={shape}];\n'
+
+        if parent_id != 0:
+            dot += f'    {parent_id} -> {node_id} [label="{edge_label}"];\n'
+
+        if not node.is_leaf():
+            assert node.left is not None, "expected left child to exist"
+            dot += self._node_to_dot(node.left, node_id, "<=")
+            assert node.right is not None, "expected right child to exist"
+            dot += self._node_to_dot(node.right, node_id, ">")
+
+        return dot
+
+    def _format_class_prob(self, num: float) -> str:
+        if num == 0:
+            return "0"
+        return f"{num:.2f}"
+
+    def _format_class_probs_array(
+        self, class_probs: Any, num_samples: int, max_per_line: int = 5
+    ) -> str:
+        # add line breaks to avoid very long lines
+        flat_class_probs = class_probs.flatten()
+        formatted = [self._format_class_prob(v) for v in flat_class_probs]
+        lines = [
+            formatted[i : i + max_per_line]
+            for i in range(0, len(formatted), max_per_line)
+        ]
+        return f"num_samples={num_samples}\\n" + "\\n".join(
+            [", ".join(line) for line in lines]
+        )
+
+    def predict(self, X: Any) -> Any:
+        predictions = [self._predict_single(x) for _, x in X.iterrows()]
+        return np.array(predictions)
+
+    def predict_proba(self, X: Any) -> Any:
+        return np.array([self._predict_proba_single(x) for _, x in X.iterrows()])
+
+    def _get_leaf(self, X: Any) -> DecisionTreeNode:
+        node = self.root
+        while not node.is_leaf():
+            if X[node.feature] <= node.threshold:
+                assert node.left is not None, "expected left child to exist"
+                node = node.left
+            else:
+                assert node.right is not None, "expected right child to exist"
+                node = node.right
+        return node
+
+    def _predict_single(self, x: Any) -> str:
+        node = self._get_leaf(x)
+        # map index to class name
+        return self.classes_[np.argmax(node.class_probs)]
+
+    def _predict_proba_single(self, x: Any) -> Any:
+        node = self._get_leaf(x)
+        return node.class_probs
+
+    def apply(self, X: Any) -> Any:
+        ids = [self._apply_single(x) for _, x in X.iterrows()]
+        return np.array(ids)
+
+    def _apply_single(self, x: Any) -> int:
+        node = self._get_leaf(x)
+        return node.id
+
+    def codegen(
+        self,
+        dummy_col_2_col_val: Dict[str, Tuple[str, Any]],
+        lines: List[str],
+        unsafe_leaves: List[int],
+    ) -> None:
+        # generates python code for the decision tree
+        def codegen_node(node: DecisionTreeNode, depth: int) -> None:
+            indent = "    " * (depth + 1)
+            if node.is_leaf():
+                lines.append(handle_leaf(node, indent, unsafe_leaves))
+            else:
+                name = node.feature
+                threshold = node.threshold
+                if name in dummy_col_2_col_val:
+                    (orig_name, value) = dummy_col_2_col_val[name]
+                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
+                    assert (
+                        threshold == 0.5
+                    ), f"expected threshold to be 0.5 but is {threshold}"
+                else:
+                    predicate = (
+                        f"{indent}if context.get_value('{name}') <= {threshold}:"
+                    )
+                lines.append(predicate)
+                assert node.left is not None, "expected left child to exist"
+                codegen_node(node.left, depth + 1)
+                lines.append(f"{indent}else:")
+                assert node.right is not None, "expected right child to exist"
+                codegen_node(node.right, depth + 1)
+
+        def handle_leaf(
+            node: DecisionTreeNode, indent: str, unsafe_leaves: List[int]
+        ) -> str:
+            """
+            This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
+            will return "unsure" (i.e. None).
+            """
+            if node.id in unsafe_leaves:
+                return f"{indent}return None"
+            class_probas = node.class_probs
+            return f"{indent}return {best_probas_and_indices(class_probas)}"
+
+        def best_probas_and_indices(class_probas: Any) -> str:
+            """
+            Given a list of tuples (proba, idx), this function returns a string in which the tuples are
+            sorted by proba in descending order. E.g.:
+            Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)]
+            this function returns
+            "[(0.5, 1), (0.3, 0), (0.2, 2)]"
+            """
+            # we generate a list of tuples (proba, idx) sorted by proba in descending order
+            # idx is the index of a choice
+            # we only generate a tuple if proba > 0
+            probas_indices_sorted = sorted(
+                [
+                    (proba, index)
+                    for index, proba in enumerate(class_probas)
+                    if proba > 0
+                ],
+                key=lambda x: x[0],
+                reverse=True,
+            )
+            probas_indices_sorted_str = ", ".join(
+                f"({value:.3f}, {index})" for value, index in probas_indices_sorted
+            )
+            return f"[{probas_indices_sorted_str}]"
+
+        codegen_node(self.root, 1)
diff --git a/torchgen/_autoheuristic/train.py b/torchgen/_autoheuristic/train.py
index 78a16c4..4e8dd33 100644
--- a/torchgen/_autoheuristic/train.py
+++ b/torchgen/_autoheuristic/train.py
@@ -2,7 +2,6 @@
 
 import argparse
 import json
-import sys
 import warnings
 
 import pandas as pd  # type: ignore[import-untyped]
@@ -64,6 +63,15 @@
             action="store_true",
             help="Export heuristic to graphviz dot.",
         )
+        self.parser.add_argument(
+            "--ranking",
+            type=int,
+            default=None,
+            help="""
+                Makes AutoHeuristic learn a heuristic that ranks choices instead of predicting a single choice.
+                The argument is the number of choices the heuristic will provide.
+            """,
+        )
 
     def parse_args(self):
         return self.parser.parse_args()
@@ -87,6 +95,7 @@
             self.args.nrows,
             self.args.heuristic_name,
             self.args.save_dot,
+            self.args.ranking is not None,
         )
 
     def filter_df(self, df):
@@ -138,9 +147,6 @@
             and str(metadata.device_capa) == "{device_capa}"
         )"""
 
-    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
-        pass
-
     def codegen_boilerplate(
         self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
     ):
@@ -149,63 +155,7 @@
     def gen_predict_fn_def(self):
         pass
 
-    def dt_to_python(
-        self,
-        dt,
-        metadata,
-        feature_names,
-        dummy_col_2_col_val,
-        heuristic_name,
-        threshold,
-        unsafe_leaves=None,
-    ):
-        tree_ = dt.tree_
-        feature_name = [
-            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
-        ]
-
-        lines = []
-        device_capa = metadata["device_capa"]
-        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
-        opt_name = metadata["name"]
-        lines.append(
-            self.codegen_boilerplate(
-                heuristic_name,
-                opt_name,
-                threshold,
-                metadata["shared_memory"],
-                device_capa_str,
-                dt,
-            )
-        )
-        fn_def = f"\n    {self.gen_predict_fn_def()}"
-        lines.append(fn_def)
-
-        def dt_to_python(node, depth):
-            indent = "    " * (depth + 1)
-            false_predicate = ""
-            if tree_.feature[node] != -2:
-                name = feature_name[node]
-                threshold = tree_.threshold[node]
-                if name in dummy_col_2_col_val:
-                    (orig_name, value) = dummy_col_2_col_val[name]
-                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
-                    if threshold != 0.5:
-                        print(f"expected threshold to be 0.5 but is {threshold}")
-                        sys.exit(1)
-                else:
-                    predicate = (
-                        f"{indent}if context.get_value('{name}') <= {threshold}:"
-                    )
-                lines.append(predicate)
-                dt_to_python(tree_.children_left[node], depth + 1)
-                lines.append(f"{indent}else:")
-                dt_to_python(tree_.children_right[node], depth + 1)
-            else:
-                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))
-
-        dt_to_python(0, 1)
-
+    def write_heuristic_to_file(self, lines, heuristic_name):
         output_file = (
             f"../../../torch/_inductor/autoheuristic/artifacts/_{heuristic_name}.py"
         )
diff --git a/torchgen/_autoheuristic/train_decision.py b/torchgen/_autoheuristic/train_decision.py
index f2270d2..cd0b3ff 100644
--- a/torchgen/_autoheuristic/train_decision.py
+++ b/torchgen/_autoheuristic/train_decision.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import pandas as pd  # type: ignore[import-untyped]
+from ah_tree import DecisionTree
 from scipy.stats import gmean
 from sklearn.model_selection import train_test_split
 from sklearn.tree import DecisionTreeClassifier
@@ -102,8 +103,21 @@
         leaf_ids = model.apply(df[feature_columns])
         return predictions, proba, leaf_ids
 
+    def ranking_num_choices(self):
+        # if the heuristic is used for ranking, this function returns the number
+        # of choices that the heuristic will return
+        if self.args.ranking is None:
+            return 5
+        return self.args.ranking
+
     def train_and_evaluate_models(
-        self, datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
+        self,
+        datasets,
+        max_depths,
+        min_samples_leafs,
+        criterion_list,
+        feature_columns,
+        ranking=False,
     ):
         """
         Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
@@ -131,7 +145,20 @@
             )
             df_train = datasets["train"]
             df_val = datasets["val"]
-            model.fit(df_train[feature_columns], df_train["winner"])
+            if ranking:
+                model.fit(
+                    df_train[feature_columns],
+                    df_train["winner"],
+                    sample_weight=df_train["relative_performance"],
+                )
+            else:
+                model.fit(df_train[feature_columns], df_train["winner"])
+
+            model = DecisionTree(model, feature_columns)
+
+            if ranking:
+                model.prune(df_train, "winner", k=self.ranking_num_choices())
+
             unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
             predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)
 
@@ -145,11 +172,18 @@
                 wrong_pct=wrong_pct,
                 unsafe_leaves=unsafe_leaves,
                 leaf_ids=leaf_ids,
+                k=self.ranking_num_choices(),
+                ranking=ranking,
             )
             safe_proba = evaluator.get_safe_proba()
             print(f"safe_proba={safe_proba}")
 
             def eval(name, df):
+                if ranking:
+                    # when ranking is enabled, we duplicate each input for each choice that
+                    # is almost as good as the best choice
+                    # we do not want to evaluate the same input multiple times, so we remove duplicates here
+                    df = df[df["winner"] == df["actual_winner"]]
                 predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
                 evaluator = DecisionEvaluator(
                     self,
@@ -161,6 +195,8 @@
                     threshold=safe_proba,
                     unsafe_leaves=unsafe_leaves,
                     leaf_ids=leaf_ids,
+                    k=self.ranking_num_choices(),
+                    ranking=ranking,
                 )
                 return evaluator.get_results()
 
@@ -202,7 +238,7 @@
         """
         return (0.15, 0.15)
 
-    def prepare_datasets(self, df, other_datasets, cat_feature2cats):
+    def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
         """
         Splits the dataframe into train, val, and test sets.
         Also adds other datasets, specified by the user, to the train set.
@@ -219,24 +255,16 @@
             df_train_val, test_size=val_size / train_val_size, random_state=42
         )
         datasets = {"train": df_train, "val": df_val, "test": df_test}
-        self.add_real_datasets(datasets, other_datasets, cat_feature2cats)
+        self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
         return datasets
 
     def export_to_dot(self, best_model, df, feature_columns):
         """
         Export a learned decision tree to a dot file.
         """
-        from sklearn import tree
-
-        tree.export_graphviz(
-            best_model,
-            out_file="best_model.dot",
-            feature_names=df[feature_columns].columns,
-            class_names=[str(c) for c in best_model.classes_],
-            filled=True,
-            rounded=True,
-            special_characters=True,
-        )
+        dot_str = best_model.to_dot()
+        with open("best_model.dot", "w") as f:
+            f.write(dot_str)
 
     def get_feature_columns(self, df):
         """
@@ -250,20 +278,36 @@
             "avail_choices",
             "choice2time",
             "index",
+            "actual_winner",
+            "relative_performance",
         ]
         feature_columns = [col for col in df.columns if col not in exclude_columns]
         return feature_columns
 
-    def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
+    def add_training_data(self, df_train, datasets):
+        return datasets["train"]
+
+    def main(
+        self,
+        log_path,
+        other_datasets,
+        nrows,
+        heuristic_name,
+        save_dot=False,
+        ranking=False,
+    ):
         """
         Main function that trains a decision tree and generates a heuristic.
         """
         # TODO: Enable apply_filters
         (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
-            log_path, nrows=nrows, apply_filters=False
+            log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
         )
         print(df["winner"].value_counts())
-        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats)
+        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
+        df_train = self.add_training_data(datasets["train"], datasets)
+        datasets["train"] = df_train
+
         feature_columns = self.get_feature_columns(df)
         grid_search_values = self.get_grid_search_values()
         max_depths = grid_search_values["max_depth"]
@@ -275,28 +319,44 @@
             best_model_safe_proba,
             unsafe_leaves,
         ) = self.train_and_evaluate_models(
-            datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
+            datasets,
+            max_depths,
+            min_samples_leafs,
+            criterion_list,
+            feature_columns,
+            ranking=ranking,
         )
 
+        if ranking:
+            columns_to_keep = [
+                "set",
+                "total",
+                "top_k_correct",
+                "top_k_wrong",
+                "top_k_unsure",
+                "wrong_max_spdup_k",
+                "wrong_gman_spdup_k",
+            ]
+            results_df = results_df[columns_to_keep]
         # prints results for all models and datasets
         print(results_df.to_string())
 
-        # prints results grouped by dataset
-        for set_name in results_df["set"].unique():
-            dataset_results = results_df[results_df["set"] == set_name]
-            dataset_results = dataset_results.sort_values(by="correct")
-            print(dataset_results.to_string() + "\n")
+        if not ranking:
+            # prints results grouped by dataset
+            for set_name in results_df["set"].unique():
+                dataset_results = results_df[results_df["set"] == set_name]
+                dataset_results = dataset_results.sort_values(by="correct")
+                print(dataset_results.to_string() + "\n")
 
         if best_model is not None:
             if save_dot:
                 self.export_to_dot(best_model, df, feature_columns)
-            self.dt_to_python(
+            self.codegen(
                 best_model,
                 metadata,
-                feature_columns,
-                dummy_col_2_col_val,
                 heuristic_name,
                 best_model_safe_proba,
+                dummy_col_2_col_val,
                 unsafe_leaves,
             )
         else:
@@ -304,7 +364,14 @@
                 "All learned models have too many wrong predictions, so no heuristic was generated"
             )
 
-    def get_df(self, log_path, cat_feature2cats=None, nrows=None, apply_filters=False):
+    def get_df(
+        self,
+        log_path,
+        cat_feature2cats=None,
+        nrows=None,
+        apply_filters=False,
+        add_near_best=False,
+    ):
         """
         Parses the log file and processes the data into a dataframe that can be used for training.
         """
@@ -314,14 +381,19 @@
 
         def calculate_stats(group):
             count = len(group)
-            mean = group["feedback"].mean()
-            std = group["feedback"].std()
-            relative_std = (std / mean) * 100 if mean != 0 else np.inf
+            has_inf = np.isinf(group["feedback"]).any()
+            if has_inf:
+                relative_std = np.inf
+                median = np.inf
+            else:
+                mean = group["feedback"].mean()
+                std = group["feedback"].std()
+                relative_std = (std / mean) * 100 if mean != 0 else np.inf
+                median = group["feedback"].median()
             if relative_std > 5:
                 times = group["feedback"].tolist()
                 times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
                 log.debug("High relative std: %f. times=%s", relative_std, times_str)
-            median = group["feedback"].median()
             return pd.Series(
                 {
                     "count": count,
@@ -385,6 +457,28 @@
             .reset_index()
         )
 
+        def add_near_best_configs(df):
+            new_rows = []
+
+            for index, row in df.iterrows():
+                dictionary = json.loads(row["choice2time"])
+                min_value = min(dictionary.values())
+
+                for key, value in dictionary.items():
+                    new_row = row.copy()
+                    relative_performance = min_value / value
+                    new_row["relative_performance"] = relative_performance
+                    if relative_performance is None or relative_performance is np.inf:
+                        breakpoint()
+                    new_row["actual_winner"] = row["winner"]
+                    new_row["winner"] = key
+                    if relative_performance >= 0.95:
+                        new_rows.append(new_row)
+
+            return pd.DataFrame(new_rows).reset_index(drop=True)
+
+        if add_near_best:
+            results = add_near_best_configs(results)
         (results, added_categorical_features) = self.add_new_features(results)
         categorical_features += added_categorical_features
 
@@ -409,27 +503,6 @@
         indent = " " * num_spaces
         return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])
 
-    def best_probas_and_indices(self, class_probas):
-        """
-        Given a list of tuples (proba, idx), this function returns a string in which the tuples are sorted by proba in
-        descending order. E.g.:
-        Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)]
-        this function returns
-        "[(0.5, 1), (0.3, 0), (0.2, 2)]"
-        """
-        # we generate a list of tuples (proba, idx) sorted by proba in descending order
-        # idx is the index of a choice
-        # we only generate a tuple if proba > 0
-        probas_indices_sorted = sorted(
-            [(proba, index) for index, proba in enumerate(class_probas) if proba > 0],
-            key=lambda x: x[0],
-            reverse=True,
-        )
-        probas_indices_sorted_str = ", ".join(
-            f"({value:.3f}, {index})" for value, index in probas_indices_sorted
-        )
-        return f"[{probas_indices_sorted_str}]"
-
     def get_default_config(self, row):
         """
         Returns the default config for a given sample. The default config could for example be the config that is
@@ -438,17 +511,6 @@
         """
         return None
 
-    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
-        """
-        This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
-        will return "unsure" (i.e. None).
-        """
-        if node in unsafe_leaves:
-            return f"{indent}return None"
-        leaf_num_samples = tree_.n_node_samples[node]
-        class_probas = tree_.value[node][0]
-        return f"{indent}return {self.best_probas_and_indices(class_probas)}"
-
     def gen_predict_fn_def(self):
         """
         Generates the definition of the predict function.
@@ -456,7 +518,7 @@
         return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"
 
     def codegen_boilerplate(
-        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
+        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
     ):
         """
         Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
@@ -496,23 +558,56 @@
         return None
 
     def fill_choices(self) -> None:
-{self.gen_classes(dt.classes_, num_spaces=8)}
+{self.gen_classes(classes, num_spaces=8)}
 
     def get_name(self) -> str:
         return '{opt_name}'"""
         return boiler_plate
 
-    def add_real_datasets(self, datasets, other_datasets, cat_feature2cats):
+    def add_real_datasets(
+        self, datasets, other_datasets, cat_feature2cats, ranking=False
+    ):
         """
         Adds datasets specified by the user to the datasets dictionary.
         """
         if other_datasets:
             for name, path in other_datasets:
                 (df_other, choices, _, _, _) = self.get_df(
-                    path, cat_feature2cats=cat_feature2cats, apply_filters=False
+                    path,
+                    cat_feature2cats=cat_feature2cats,
+                    apply_filters=False,
+                    add_near_best=ranking,
                 )
                 datasets[name] = df_other
 
+    def codegen(
+        self,
+        tree,
+        metadata,
+        heuristic_name,
+        threshold,
+        dummy_col_2_col_val,
+        unsafe_leaves,
+    ):
+        lines = []
+        device_capa = metadata["device_capa"]
+        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
+        opt_name = metadata["name"]
+        lines.append(
+            self.codegen_boilerplate(
+                heuristic_name,
+                opt_name,
+                threshold,
+                metadata["shared_memory"],
+                device_capa_str,
+                tree.classes_,
+            )
+        )
+        fn_def = f"\n    {self.gen_predict_fn_def()}"
+        lines.append(fn_def)
+        tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
+        self.write_heuristic_to_file(lines, heuristic_name)
+
 
 @dataclass
 class AccuracyMetrics:
@@ -552,6 +647,8 @@
 class RankingMetrics:
     # Number of predictions where best choice is in top k choices
     num_correct: int
+    # Number of predictions where best choice is not in top k choices
+    num_wrong: int
     # Maximum speedup of best choice over best choice in top k (this tells us how much better the best choice, which
     # is not in top k, is over the best choice in top k)
     max_speedup: float
@@ -563,6 +660,7 @@
     def to_map(self):
         return {
             "top_k_correct": self.num_correct,
+            "top_k_wrong": self.num_wrong,
             "wrong_max_speedup_k": self.max_speedup,
             "wrong_gmean_speedup_k": self.gmean_speedup,
             "top_k_unsure": self.unsure,
@@ -618,9 +716,10 @@
         probas,
         wrong_pct=0.01,
         threshold=0.0,
-        k=3,
+        k=10,
         unsafe_leaves=None,
         leaf_ids=None,
+        ranking=False,
     ) -> None:
         self.train = train
         self.model = model
@@ -632,6 +731,7 @@
         self.k = k
         self.unsafe_leaves = unsafe_leaves
         self.leaf_ids = leaf_ids
+        self.ranking = ranking
 
         self.num_correct = 0
         self.num_wrong = 0
@@ -639,6 +739,7 @@
         self.wrong_probas = []
         self.speedups_wrong = []
         self.num_correct_top_k = 0
+        self.num_wrong_top_k = 0
         self.wrong_speedups_top_k = []
         self.top_k_unsure = 0
         self.num_non_default_predictions = 0
@@ -718,6 +819,7 @@
             if min_time is not None:
                 speedup = min_time / best_time
                 self.wrong_speedups_top_k.append(speedup)
+                self.num_wrong_top_k += 1
             else:
                 self.top_k_unsure += 1
                 # TODO (AlnisM): print more info (input and choices)
@@ -743,7 +845,7 @@
         Custom evaluation function that evaluates a learned decision tree.
         """
 
-        y_true = self.df["winner"]
+        y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
         i = 0
         for pred, true, prob, leaf_id in zip(
             self.predictions, y_true, self.probas, self.leaf_ids
@@ -790,6 +892,7 @@
         wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
         rankingMetrics = RankingMetrics(
             self.num_correct_top_k,
+            self.num_wrong_top_k,
             max_speedup_top_k,
             gmean_speedup_top_k,
             self.top_k_unsure,
diff --git a/torchgen/_autoheuristic/train_regression.py b/torchgen/_autoheuristic/train_regression.py
index 024095e..1fc4873 100644
--- a/torchgen/_autoheuristic/train_regression.py
+++ b/torchgen/_autoheuristic/train_regression.py
@@ -34,7 +34,15 @@
     def __init__(self):
         super().__init__()
 
-    def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
+    def main(
+        self,
+        log_path,
+        other_datasets,
+        nrows,
+        heuristic_name,
+        save_dot=False,
+        ranking=False,
+    ):
         """
         Main function that trains a decision tree and generates a heuristic.
         """
@@ -357,6 +365,65 @@
             "wrong_max_ratio": wrong_max_ratio,
         }
 
+    def dt_to_python(
+        self,
+        dt,
+        metadata,
+        feature_names,
+        dummy_col_2_col_val,
+        heuristic_name,
+        threshold,
+        unsafe_leaves=None,
+    ):
+        tree_ = dt.tree_
+        feature_name = [
+            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
+        ]
+
+        lines = []
+        device_capa = metadata["device_capa"]
+        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
+        opt_name = metadata["name"]
+        lines.append(
+            self.codegen_boilerplate(
+                heuristic_name,
+                opt_name,
+                threshold,
+                metadata["shared_memory"],
+                device_capa_str,
+                dt,
+            )
+        )
+        fn_def = f"\n    {self.gen_predict_fn_def()}"
+        lines.append(fn_def)
+
+        def dt_to_python(node, depth):
+            indent = "    " * (depth + 1)
+            false_predicate = ""
+            if tree_.feature[node] != -2:
+                name = feature_name[node]
+                threshold = tree_.threshold[node]
+                if name in dummy_col_2_col_val:
+                    (orig_name, value) = dummy_col_2_col_val[name]
+                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
+                    assert (
+                        threshold == 0.5
+                    ), f"expected threshold to be 0.5 but is {threshold}"
+                else:
+                    predicate = (
+                        f"{indent}if context.get_value('{name}') <= {threshold}:"
+                    )
+                lines.append(predicate)
+                dt_to_python(tree_.children_left[node], depth + 1)
+                lines.append(f"{indent}else:")
+                dt_to_python(tree_.children_right[node], depth + 1)
+            else:
+                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))
+
+        dt_to_python(0, 1)
+
+        self.write_heuristic_to_file(lines, heuristic_name)
+
     def handle_leaf(self, tree_, node, indent, unsafe_leaves):
         """
         Generates the code for a leaf node. This is just the value predicted by the regression tree.
@@ -368,7 +435,7 @@
         return "def predict(self, context: AHContext) -> float:"
 
     def codegen_boilerplate(
-        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
+        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
     ):
         """
         Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,