blob: 57ac4e1d1cc691d55b9e4419529c24e6c77a70d0 [file] [log] [blame]
# mypy: ignore-errors
from train import AHTrain
from torch._inductor.fx_passes.pad_mm import pad_mm_operations
class AHTrainPadMM(AHTrain):
    """AHTrain subclass that adds one feature column per pad_mm operation.

    The operations come from ``pad_mm_operations()``; each one is read for
    its ``name``, ``func`` and ``is_categorical`` attributes.
    """

    def __init__(self):
        """Initialize by delegating to the ``AHTrain`` constructor."""
        super().__init__()

    def add_new_features(self, results):
        """Add a column to ``results`` for every pad_mm operation.

        Args:
            results: Table-like object supporting ``apply(func, axis=1)``
                and item assignment by column name (presumably a pandas
                DataFrame -- TODO confirm against the base class). It is
                modified in place.

        Returns:
            A ``(results, names)`` tuple where ``results`` is the same
            object that was passed in and ``names`` lists the names of the
            operations flagged as categorical.
        """
        operations = pad_mm_operations()

        # First pass: compute each operation over the rows and store the
        # outcome under the operation's name.
        for operation in operations:
            column = results.apply(operation.func, axis=1)
            results[operation.name] = column

        # Second pass: collect the names of the categorical operations.
        categorical_names = []
        for operation in operations:
            if operation.is_categorical:
                categorical_names.append(operation.name)

        return (results, categorical_names)
if __name__ == "__main__":
    # Script entry point: build the pad_mm trainer and generate its heuristic.
    trainer = AHTrainPadMM()
    trainer.generate_heuristic()