Creating a helper function to generate an unique name for an attr in a module (#64970)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64970
Add a helper function to create an unique name for an attr.
This can be used when we want to add a weight to a module.
Test Plan: run CI.
Reviewed By: jfix71
Differential Revision: D30921497
fbshipit-source-id: 598569d107df8b516ff12920a4bef3a42577e987
diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py
index 38ead16..fb9ba2a 100644
--- a/torch/fx/experimental/const_fold.py
+++ b/torch/fx/experimental/const_fold.py
@@ -1,8 +1,8 @@
import operator
-import re
from typing import Dict, Set, List, Optional, Union
import torch.fx
+import torch.fx.experimental.fx_acc.acc_utils as acc_utils
from torch.fx.passes.split_module import split_module
@@ -138,18 +138,8 @@
for i in range(len(const_output_names)):
# Add a suffix to make it easier to tell these were the result of const folding.
name = const_output_names[i] + "__CF"
- # Delete all characters that are illegal in a Python identifier.
- name = re.sub("[^0-9a-zA-Z_]+", "_", name)
- if name[0].isdigit():
- name = f"_{name}"
- # Now make sure it is in fact unique to the module by incrementing suffix value.
- while hasattr(mod_traced, name):
- match = re.match(r"(.*)_(\d+)$", name)
- if match is None:
- name = name + "_1"
- else:
- base, num = match.group(1, 2)
- name = f"{base}_{int(num) + 1}"
+ # Get an unique name for the attr.
+ name = acc_utils.get_unique_attr_name_in_module(mod_traced, name)
const_output_names[i] = name
# Now track the const_output_names to what name is used in the parent graph
diff --git a/torch/fx/experimental/fx_acc/acc_utils.py b/torch/fx/experimental/fx_acc/acc_utils.py
index b331ee6..c0d550f 100644
--- a/torch/fx/experimental/fx_acc/acc_utils.py
+++ b/torch/fx/experimental/fx_acc/acc_utils.py
@@ -2,6 +2,7 @@
import json
import os
from typing import Any, Tuple, Callable, Union, Dict, List, Optional
+import re
import torch
import torch.fx
@@ -137,3 +138,22 @@
model_info_str += f"> {op_str}: {count}\n"
print(model_info_str)
+
+def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
+ """
+ Make sure the name is unique (in a module) and can represents an attr.
+ """
+ # Delete all characters that are illegal in a Python identifier.
+ name = re.sub("[^0-9a-zA-Z_]+", "_", name)
+ if name[0].isdigit():
+ name = f"_{name}"
+ # Now make sure it is in fact unique to the module by incrementing suffix value.
+ while hasattr(mod_traced, name):
+ match = re.match(r"(.*)_(\d+)$", name)
+ if match is None:
+ name = name + "_1"
+ else:
+ base, num = match.group(1, 2)
+ name = f"{base}_{int(num) + 1}"
+
+ return name