# mypy: ignore-errors
import itertools
import json
import logging
import math
import warnings
warnings.filterwarnings(
"ignore",
message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
)
from dataclasses import dataclass
import numpy as np
import pandas as pd # type: ignore[import-untyped]
from ah_tree import DecisionTree
from scipy.stats import gmean
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from train import AHTrain
log = logging.getLogger(__name__)
DEBUG = True
if DEBUG:
    log.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
ch.setFormatter(formatter)
log.addHandler(ch)
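# This file implements AHTrainDecisionTree, which trains decision trees on data collected by
# AutoHeuristic, performs a grid search over tree hyperparameters, evaluates the resulting
# models, and code-generates a learned heuristic from the best model.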
class AHTrainDecisionTree(AHTrain):
def __init__(self):
super().__init__()
def debug_time(self, row, top_k_choices):
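        """
        Prints all choices and their execution times for a sample, marking the choices that are
        among the top k predicted choices with 'TOPK'.
        """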
choices_feedback = json.loads(row["choice2time"])
timings = sorted(choices_feedback.items(), key=lambda x: x[1])
for choice, time in timings:
result = f"{choice} {time}"
if choice in top_k_choices:
result += " TOPK"
print(result)
def is_unsafe_leaf(self, row, predicted_config, choice2time):
"""
        Can be overridden by subclasses to define their own logic for deciding when a leaf is unsafe. The arguments
        are a sample that landed in the leaf, the choice predicted by the tree, and a dictionary that maps each
        choice to its execution time. One can, for example, decide to mark a leaf as unsafe if the predicted choice
        is 2x slower than the fastest choice.
If a leaf is unsafe, the learned heuristic will always return 'unsure' if an input lands in that leaf.
"""
return False
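    # A minimal sketch of a possible override (hypothetical, mirroring the example in the
    # docstring above): mark a leaf as unsafe if the predicted choice is more than 2x slower
    # than the fastest available choice.
    #
    #     def is_unsafe_leaf(self, row, predicted_config, choice2time):
    #         best_time = min(choice2time.values())
    #         return choice2time.get(predicted_config, float("inf")) > 2 * best_time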
def get_unsafe_leaves(self, model, df, feature_columns):
"""
        Given a trained decision tree and a dataframe containing the training data, returns a list of unsafe leaves.
"""
X = df[feature_columns]
y = df["winner"]
leaf_ids = model.apply(X)
unique_leaves = np.unique(leaf_ids)
unsafe_leaves = []
# Iterate over each leaf
for leaf in unique_leaves:
leaf_mask = leaf_ids == leaf
# Get samples that land in this leaf
leaf_X = X[leaf_mask]
predicted_config = model.predict(leaf_X.iloc[[0]])[0]
# For each sample, check if we should mark the leaf as unsafe
for idx, row in leaf_X.iterrows():
choice2time = json.loads(df.loc[idx, "choice2time"])
if self.is_unsafe_leaf(row, predicted_config, choice2time):
unsafe_leaves.append(leaf)
break
return unsafe_leaves
def get_allowed_wrong_prediction_pct(self):
"""
This is used to determine a threshold for when a learned heuristic returns 'unsure'.
If this function returns 0.01, we will set the probability required for the decision tree to return a decision
such that at most 1% of the predictions will be wrong on the validation set.
"""
return 0.01
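    # Example: if the validation set has 1000 samples and this returns 0.01, the confidence
    # threshold is chosen such that at most 10 predictions on the validation set are wrong;
    # low-confidence predictions are counted as 'unsure' instead.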
def get_grid_search_values(self):
"""
        Standard values for grid search. Can be overridden.
"""
return {
"max_depth": [5, 6, 7],
"min_samples_leaf": [1, 5, 10, 0.01, 0.05, 0.02],
"criterion": ["gini", "entropy"],
}
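    # A subclass that wants to search a different space could override this, e.g.
    # (the values below are illustrative only):
    #
    #     def get_grid_search_values(self):
    #         return {
    #             "max_depth": [4, 5, 6, 7, 8],
    #             "min_samples_leaf": [1, 2, 5, 10],
    #             "criterion": ["gini", "entropy"],
    #         }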
def predict(self, model, df, feature_columns):
"""
Returns the predictions, probabilities, and leaf ids for a given dataframe.
"""
predictions = model.predict(df[feature_columns])
proba = model.predict_proba(df[feature_columns])
leaf_ids = model.apply(df[feature_columns])
return predictions, proba, leaf_ids
def ranking_num_choices(self):
# if the heuristic is used for ranking, this function returns the number
# of choices that the heuristic will return
if self.args.ranking is None:
return 5
return self.args.ranking
def train_and_evaluate_models(
self,
datasets,
max_depths,
min_samples_leafs,
criterion_list,
feature_columns,
ranking=False,
):
"""
Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
"""
results = []
best_model = None
best_model_safe_proba = 0
best_model_num_correct = 0
best_model_num_wrong = 0
best_model_unsafe_leaves = []
columns = ["set", "crit", "max_depth", "min_samples_leaf"]
metrics_columns = []
for max_depth, min_samples_leaf, criterion in itertools.product(
max_depths, min_samples_leafs, criterion_list
):
print(
f"max_depth={max_depth} min_samples_leaf={min_samples_leaf} criterion={criterion}"
)
model = DecisionTreeClassifier(
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
criterion=criterion,
random_state=42,
)
df_train = datasets["train"]
df_val = datasets["val"]
if ranking:
model.fit(
df_train[feature_columns],
df_train["winner"],
sample_weight=df_train["relative_performance"],
)
else:
model.fit(df_train[feature_columns], df_train["winner"])
model = DecisionTree(model, feature_columns)
if ranking:
model.prune(df_train, "winner", k=self.ranking_num_choices())
unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)
wrong_pct = self.get_allowed_wrong_prediction_pct()
evaluator = DecisionEvaluator(
self,
model,
predictions,
df_val,
proba,
wrong_pct=wrong_pct,
unsafe_leaves=unsafe_leaves,
leaf_ids=leaf_ids,
k=self.ranking_num_choices(),
ranking=ranking,
)
safe_proba = evaluator.get_safe_proba()
print(f"safe_proba={safe_proba}")
def eval(name, df):
if ranking:
# when ranking is enabled, we duplicate each input for each choice that
# is almost as good as the best choice
# we do not want to evaluate the same input multiple times, so we remove duplicates here
df = df[df["winner"] == df["actual_winner"]]
predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
evaluator = DecisionEvaluator(
self,
model,
predictions,
df,
proba,
wrong_pct=wrong_pct,
threshold=safe_proba,
unsafe_leaves=unsafe_leaves,
leaf_ids=leaf_ids,
k=self.ranking_num_choices(),
ranking=ranking,
)
return evaluator.get_results()
for dataset_name, dataset in datasets.items():
eval_result: EvalResults = eval(dataset_name, dataset)
eval_result_metrics = eval_result.to_map()
if dataset_name == "val":
num_correct = eval_result.accuracy.num_correct
num_wrong = eval_result.accuracy.num_wrong
num_total = eval_result.accuracy.total
if num_wrong <= num_total * wrong_pct:
if num_correct > best_model_num_correct:
print(
f"new best model with {num_correct} correct and {num_wrong} wrong"
)
best_model = model
best_model_num_correct = num_correct
best_model_num_wrong = num_wrong
best_model_safe_proba = safe_proba
best_model_unsafe_leaves = unsafe_leaves
result = (dataset_name, criterion, max_depth, min_samples_leaf)
result += tuple(eval_result_metrics.values())
results.append(result)
if len(metrics_columns) == 0:
metrics_columns = list(eval_result_metrics.keys())
columns += metrics_columns
return (
pd.DataFrame(results, columns=columns),
best_model,
best_model_safe_proba,
best_model_unsafe_leaves,
)
def get_test_and_val_size(self):
"""
        Returns the sizes of the test and validation sets as fractions of the whole dataset.
"""
return (0.15, 0.15)
def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
"""
Splits the dataframe into train, val, and test sets.
Also adds other datasets, specified by the user, to the train set.
"""
test_size, val_size = self.get_test_and_val_size()
# Split into train+val and test
df_train_val, df_test = train_test_split(
df, test_size=test_size, random_state=42
)
# Split train+val inputs into train and val
train_val_size = 1 - test_size
df_train, df_val = train_test_split(
df_train_val, test_size=val_size / train_val_size, random_state=42
)
datasets = {"train": df_train, "val": df_val, "test": df_test}
self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
return datasets
def export_to_dot(self, best_model, df, feature_columns):
"""
Export a learned decision tree to a dot file.
"""
dot_str = best_model.to_dot()
with open("best_model.dot", "w") as f:
f.write(dot_str)
def get_feature_columns(self, df):
"""
        The dataframe contains columns, such as 'winner' and 'speedup', that are not features and are only used for
        debugging purposes. This function returns the columns that are actually features.
"""
exclude_columns = [
"speedup",
"winner",
"target",
"avail_choices",
"choice2time",
"index",
"actual_winner",
"relative_performance",
]
feature_columns = [col for col in df.columns if col not in exclude_columns]
return feature_columns
def add_training_data(self, df_train, datasets):
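        """
        Can be overridden by subclasses to add additional data to the training set. By default,
        the training set is returned unchanged.
        """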
return datasets["train"]
def main(
self,
log_path,
other_datasets,
nrows,
heuristic_name,
save_dot=False,
ranking=False,
):
"""
Main function that trains a decision tree and generates a heuristic.
"""
# TODO: Enable apply_filters
(df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
)
self.dummy_col_2_col_val = dummy_col_2_col_val
datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
df_train = self.add_training_data(datasets["train"], datasets)
datasets["train"] = df_train
print(datasets["train"]["winner"].value_counts().to_string())
feature_columns = self.get_feature_columns(df)
grid_search_values = self.get_grid_search_values()
max_depths = grid_search_values["max_depth"]
min_samples_leafs = grid_search_values["min_samples_leaf"]
criterion_list = grid_search_values["criterion"]
(
results_df,
best_model,
best_model_safe_proba,
unsafe_leaves,
) = self.train_and_evaluate_models(
datasets,
max_depths,
min_samples_leafs,
criterion_list,
feature_columns,
ranking=ranking,
)
if ranking:
columns_to_keep = [
"set",
"crit",
"max_depth",
"min_samples_leaf",
"total",
"top_k_correct",
"top_k_wrong",
"top_k_unsure",
"wrong_max_speedup_k",
"wrong_gmean_speedup_k",
]
results_df = results_df[columns_to_keep]
# prints results for all models and datasets
print(results_df.to_string())
sort_metric = "top_k_correct" if ranking else "correct"
# prints results grouped by dataset
for set_name in results_df["set"].unique():
dataset_results = results_df[results_df["set"] == set_name]
dataset_results = dataset_results.sort_values(by=sort_metric)
print(dataset_results.to_string() + "\n")
if best_model is not None:
if save_dot:
self.export_to_dot(best_model, df, feature_columns)
self.codegen(
best_model,
metadata,
heuristic_name,
best_model_safe_proba,
dummy_col_2_col_val,
unsafe_leaves,
)
else:
print(
"All learned models have too many wrong predictions, so no heuristic was generated"
)
def get_df(
self,
log_path,
cat_feature2cats=None,
nrows=None,
apply_filters=False,
add_near_best=False,
):
"""
Parses the log file and processes the data into a dataframe that can be used for training.
"""
(df, metadata, features, categorical_features, choices) = self.parse_log(
log_path, nrows
)
def calculate_stats(group):
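            # For each (input, choice) group, compute the number of measurements, the relative
            # standard deviation (in %), and the median execution time.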
count = len(group)
has_inf = np.isinf(group["feedback"]).any()
if has_inf:
relative_std = np.inf
median = np.inf
else:
mean = group["feedback"].mean()
std = group["feedback"].std()
relative_std = (std / mean) * 100 if mean != 0 else np.inf
median = group["feedback"].median()
if relative_std > 5:
times = group["feedback"].tolist()
times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
log.debug("High relative std: %f. times=%s", relative_std, times_str)
return pd.Series(
{
"count": count,
"relative_std": relative_std,
"median_execution_time": median,
}
)
feature_columns = features
stats = (
df.groupby(feature_columns + ["choice"], as_index=False)
.apply(calculate_stats, include_groups=False)
.reset_index()
)
        # TODO: We have to be careful with removing certain choices, because if we e.g. remove the winner, the
        # heuristic will end up learning wrong things. But execution times with high variance are also bad.
if apply_filters:
# Filter out inputs with less than 3 measurements or high relative std
valid_stats = stats[(stats["count"] >= 3) & (stats["relative_std"] <= 5)]
# Group by input features and count how many valid choices we have for each input
valid_inputs = valid_stats.groupby(feature_columns).filter(
lambda x: len(x) >= 2
)
else:
valid_inputs = stats
# Compute the winner and speedup for each valid input
def get_winner_and_speedup(group):
assert len(group) >= 2, "Need at least 2 choices"
sorted_group = group.sort_values("median_execution_time")
winner = sorted_group.iloc[0]["choice"]
winning_time = sorted_group.iloc[0]["median_execution_time"]
second_best_time = sorted_group.iloc[1]["median_execution_time"]
speedup = second_best_time / winning_time
unique_choices = group["choice"].unique()
choice2time = {}
for row in group.itertuples():
choice2time[row.choice] = row.median_execution_time
assert len(unique_choices) == len(
group
), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"
return pd.Series(
{
"winner": winner,
"speedup": speedup,
"avail_choices": unique_choices,
"choice2time": json.dumps(choice2time),
}
)
results = (
valid_inputs.groupby(feature_columns, as_index=False)
.filter(lambda x: len(x) >= 2)
.groupby(feature_columns, as_index=False)
.apply(get_winner_and_speedup, include_groups=False)
.reset_index()
)
def add_near_best_configs(df):
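            # For each input, add one row per choice whose execution time is within ~2% of the
            # best choice (relative_performance >= 0.98). The original winner is kept in
            # 'actual_winner', and 'winner' is set to the near-best choice. This is used when
            # training a ranking heuristic.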
new_rows = []
for index, row in df.iterrows():
dictionary = json.loads(row["choice2time"])
min_value = min(dictionary.values())
for key, value in dictionary.items():
new_row = row.copy()
relative_performance = min_value / value
new_row["relative_performance"] = relative_performance
                    if relative_performance is None or np.isinf(relative_performance):
                        raise ValueError(
                            f"invalid relative performance {relative_performance} for choice {key}"
                        )
new_row["actual_winner"] = row["winner"]
new_row["winner"] = key
if relative_performance >= 0.98:
new_rows.append(new_row)
return pd.DataFrame(new_rows).reset_index(drop=True)
if add_near_best:
results = add_near_best_configs(results)
(results, added_categorical_features) = self.add_new_features(results)
categorical_features += added_categorical_features
(
results,
cat_feature2cats,
dummy_col_2_col_val,
) = self.handle_categorical_features(
cat_feature2cats, categorical_features, results
)
return (results, choices, cat_feature2cats, dummy_col_2_col_val, metadata)
def ranking_always_included_choices(self):
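        """
        Can be overridden by subclasses to return choices that should always be included in the
        top k choices returned by a ranking heuristic.
        """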
return []
def gen_classes(self, classes, num_spaces):
"""
If classes=['choice1', 'choice2', 'choice3'], then this function returns
the following string:
self.choices.append('choice1')
self.choices.append('choice2')
self.choices.append('choice3')
Used in the generated heuristic to map the index of a choice to its name.
"""
indent = " " * num_spaces
return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])
def get_default_config(self, row):
"""
        Returns the default config for a given sample. The default config could, for example, be the config that is
        chosen by a current handwritten heuristic. This can, for example, be used in is_unsafe_leaf to
        compare the predicted config with the default config.
"""
return None
def gen_predict_fn_def(self):
"""
Generates the definition of the predict function.
"""
return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"
def codegen_boilerplate(
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
):
"""
Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
etc.
"""
boiler_plate = f"""# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/{opt_name}/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class {heuristic_name}(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
{self.gen_precondition(opt_name, shared_memory, device_capa)}
def get_confidence_threshold(self) -> float:
return {threshold}
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
{self.gen_classes(classes, num_spaces=8)}
def get_name(self) -> str:
return '{opt_name}'"""
return boiler_plate
def add_real_datasets(
self, datasets, other_datasets, cat_feature2cats, ranking=False
):
"""
Adds datasets specified by the user to the datasets dictionary.
"""
if other_datasets:
for name, path in other_datasets:
(df_other, choices, _, _, _) = self.get_df(
path,
cat_feature2cats=cat_feature2cats,
apply_filters=False,
add_near_best=ranking,
)
datasets[name] = df_other
def codegen(
self,
tree,
metadata,
heuristic_name,
threshold,
dummy_col_2_col_val,
unsafe_leaves,
):
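        """
        Generates the code for the learned heuristic (boilerplate plus the code-generated decision
        tree) and writes it to a file.
        """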
lines = []
device_capa = metadata["device_capa"]
device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
opt_name = metadata["name"]
lines.append(
self.codegen_boilerplate(
heuristic_name,
opt_name,
threshold,
metadata["shared_memory"],
device_capa_str,
tree.classes_,
)
)
fn_def = f"\n {self.gen_predict_fn_def()}"
lines.append(fn_def)
tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
self.write_heuristic_to_file(lines, heuristic_name)
@dataclass
class AccuracyMetrics:
# Number of correct predictions
num_correct: int
# Number of wrong predictions
num_wrong: int
# Number of predictions where model is unsure
num_unsure: int
# Total number of predictions
total: int
def to_map(self):
return {
"correct": self.num_correct,
"wrong": self.num_wrong,
"unsure": self.num_unsure,
"total": self.total,
}
@dataclass
class WrongSpeedupMetrics:
# If the model predicted the wrong choice, this is the maximum speedup of the best choice over the predicted choice
max_speedup: float
# For all wrong predictions, this is the geometric mean of the speedups of the best choices over the predicted choices
gmean_speedup: float
def to_map(self):
return {
"wrong_max_speedup": self.max_speedup,
"wrong_gmean_speedup": self.gmean_speedup,
}
@dataclass
class RankingMetrics:
# Number of predictions where best choice is in top k choices
num_correct: int
# Number of predictions where best choice is not in top k choices
num_wrong: int
    # Maximum speedup of the overall best choice over the best choice within the top k (this tells us how much
    # better the best choice, which is not in the top k, is than the best choice that is in the top k)
max_speedup: float
# Geometric mean of speedups of best choice over best choice in top k
gmean_speedup: float
# Number of predictions where model is unsure
unsure: int
def to_map(self):
return {
"top_k_correct": self.num_correct,
"top_k_wrong": self.num_wrong,
"wrong_max_speedup_k": self.max_speedup,
"wrong_gmean_speedup_k": self.gmean_speedup,
"top_k_unsure": self.unsure,
}
@dataclass
class DefaultComparisonMetrics:
# Maximum speedup of predicted choice over default choice
max_speedup: float
# Geometric mean of speedups of predicted choices over default choices
gmean_speedup: float
# Maximum speedup of default choice over predicted choice
max_slowdown: float
# Number of predictions where the predicted choice is not the default choice
non_default_predictions: int
# Number of predictions where the default choice is better than the predicted choice
    default_better: int
def to_map(self):
return {
"max_speedup_over_default": self.max_speedup,
"gmean_speedup_over_default": self.gmean_speedup,
"max_speedup_default_over_heuristic": self.max_slowdown,
"non_default_predictions": self.non_default_predictions,
"default_better": self.default_better,
}
@dataclass
class EvalResults:
accuracy: AccuracyMetrics
speedup: WrongSpeedupMetrics
ranking: RankingMetrics
default_comparison: DefaultComparisonMetrics
def to_map(self):
return {
**self.accuracy.to_map(),
**self.speedup.to_map(),
**self.ranking.to_map(),
**self.default_comparison.to_map(),
}
class DecisionEvaluator:
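    """
    Evaluates the predictions of a learned decision tree and computes accuracy, speedup, ranking,
    and default-comparison metrics (see EvalResults).
    """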
def __init__(
self,
train,
model,
predictions,
df,
probas,
wrong_pct=0.01,
threshold=0.0,
k=10,
unsafe_leaves=None,
leaf_ids=None,
ranking=False,
) -> None:
self.train = train
self.model = model
self.predictions = predictions
self.df = df
self.probas = probas
self.wrong_pct = wrong_pct
self.threshold = threshold
self.k = k
self.unsafe_leaves = unsafe_leaves
self.leaf_ids = leaf_ids
self.ranking = ranking
self.num_correct = 0
self.num_wrong = 0
self.num_unsure = 0
self.wrong_probas = []
self.speedups_wrong = []
self.num_correct_top_k = 0
self.num_wrong_top_k = 0
self.wrong_speedups_top_k = []
self.top_k_unsure = 0
self.num_non_default_predictions = 0
self.speedups_over_default = []
self.num_default_better = 0
def compute_speedup_over_default(self, default_config, pred, i, predicted_time):
if default_config is not None:
if pred != default_config:
self.num_non_default_predictions += 1
default_time = self.get_time(self.df.iloc[i], default_config)
# TODO: We should keep track of how often this happens
if default_time is not None and not math.isinf(default_time):
speedup_over_default = default_time / predicted_time
if speedup_over_default < 1:
self.num_default_better += 1
self.speedups_over_default.append(speedup_over_default)
else:
log.debug(
"cannot compute speedup over default because default_time=%d",
default_time,
)
def get_time(self, row, choice):
choices_feedback = json.loads(row["choice2time"])
return choices_feedback.get(choice, None)
def top_k_classes(self, model, probas, k, avail_choices):
# Get classes and their corresponding probabilities
classes = model.classes_
class_proba_pairs = list(zip(classes, probas))
# Sort by probability (descending) and filter out zero probabilities
        sorted_classes = [
            c
            for c, p in sorted(class_proba_pairs, key=lambda x: x[1], reverse=True)
            if p > 0 and c in avail_choices
        ]
# Return top k choices
top_k_choices = sorted_classes[:k]
top_k_choices += self.train.ranking_always_included_choices()
top_k_choices = list(dict.fromkeys(top_k_choices))
return top_k_choices
def eval_prediction(
self, avail_choices, leaf_id, pred, true, prob, threshold, default_config, i
):
predicted_time = self.get_time(self.df.iloc[i], pred)
max_prob = max(prob)
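        # A prediction counts as 'unsure' if the sample landed in an unsafe leaf, the predicted
        # choice is not available for this input, or the confidence is at or below the threshold.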
if (
leaf_id in self.unsafe_leaves
or pred not in avail_choices
or (max_prob != 1.0 and max_prob <= threshold)
):
self.num_unsure += 1
self.speedups_over_default.append(1.0)
elif pred == true:
self.compute_speedup_over_default(default_config, pred, i, predicted_time)
self.num_correct += 1
else:
self.compute_speedup_over_default(default_config, pred, i, predicted_time)
self.num_wrong += 1
self.wrong_probas.append(max_prob)
best_time = self.get_time(self.df.iloc[i], true)
wrong_speedup = predicted_time / best_time
self.speedups_wrong.append(wrong_speedup)
def eval_ranking_prediction(self, true, top_k_choices, i):
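        """
        Checks whether the best choice is among the top k predicted choices. If it is not, records
        the speedup of the best choice over the best choice within the top k.
        """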
if true in top_k_choices:
self.num_correct_top_k += 1
else:
top_k_choices_times = []
for choice in top_k_choices:
time = self.get_time(self.df.iloc[i], choice)
if time is not None:
top_k_choices_times.append(time)
best_time = self.get_time(self.df.iloc[i], true)
min_time = min(top_k_choices_times, default=None)
if min_time is not None:
speedup = min_time / best_time
self.wrong_speedups_top_k.append(speedup)
self.num_wrong_top_k += 1
else:
self.top_k_unsure += 1
# TODO (AlnisM): print more info (input and choices)
log.debug(
"All top k choices have no time which means all top k are unavailable"
)
def get_safe_proba(self):
return self.get_results(return_safe_proba=True)
def compute_safe_proba(self, num_predictions, wrong_probas, wrong_pct):
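        """
        Computes a probability threshold such that, if predictions with a confidence at or below
        the threshold are counted as 'unsure', at most wrong_pct of the predictions are wrong.
        """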
wrong_probas.sort()
num_wrong = len(wrong_probas)
allowed_wrong = int(num_predictions * wrong_pct)
if allowed_wrong >= num_wrong:
return 0.0
too_many_wrong = num_wrong - allowed_wrong
idx = min(too_many_wrong, len(wrong_probas) - 1)
return wrong_probas[idx]
def get_results(self, return_safe_proba=False) -> EvalResults:
"""
Custom evaluation function that evaluates a learned decision tree.
"""
y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
i = 0
for pred, true, prob, leaf_id in zip(
self.predictions, y_true, self.probas, self.leaf_ids
):
avail_choices = self.df["avail_choices"].iloc[i]
top_k_choices = self.top_k_classes(
self.model, prob, k=self.k, avail_choices=avail_choices
)
assert (
true in avail_choices
), f"Best choice {true} not in available choices {avail_choices}"
default_config = self.train.get_default_config(self.df.iloc[i])
self.eval_prediction(
avail_choices,
leaf_id,
pred,
true,
prob,
self.threshold,
default_config,
i,
)
self.eval_ranking_prediction(true, top_k_choices, i)
i += 1
total = len(self.predictions)
if return_safe_proba:
return self.compute_safe_proba(total, self.wrong_probas, self.wrong_pct)
def safe_gmean(x):
return gmean(x) if x else 0
max_speedup = max(self.speedups_wrong, default=0)
gmean_speedup = safe_gmean(self.speedups_wrong)
max_speedup_top_k = max(self.wrong_speedups_top_k, default=0)
gmean_speedup_top_k = safe_gmean(self.wrong_speedups_top_k)
max_speedup_over_default = max(self.speedups_over_default, default=0)
gmean_speedup_over_default = safe_gmean(self.speedups_over_default)
max_slowdown_over_default = min(self.speedups_over_default, default=0)
accuracyMetrics = AccuracyMetrics(
self.num_correct, self.num_wrong, self.num_unsure, total
)
wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
rankingMetrics = RankingMetrics(
self.num_correct_top_k,
self.num_wrong_top_k,
max_speedup_top_k,
gmean_speedup_top_k,
self.top_k_unsure,
)
defaultComparisonMetrics = DefaultComparisonMetrics(
max_speedup_over_default,
gmean_speedup_over_default,
max_slowdown_over_default,
self.num_non_default_predictions,
self.num_default_better,
)
return EvalResults(
accuracyMetrics,
wrongSpeedupMetrics,
rankingMetrics,
defaultComparisonMetrics,
)
if __name__ == "__main__":
train = AHTrainDecisionTree()
train.generate_heuristic()