| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from torch.utils._pytree import tree_flatten, tree_unflatten |
| |
| |
def tree_map_(fn_, pytree):
    """Call ``fn_`` on every leaf of ``pytree`` for its side effects.

    No new tree is built: whatever ``fn_`` returns is discarded. The very
    same ``pytree`` object that was passed in is returned, so any change to
    it can only come from ``fn_`` itself.

    Args:
        fn_: Callable invoked once per leaf, in ``tree_flatten`` order.
        pytree: Any structure that ``tree_flatten`` can flatten.

    Returns:
        The ``pytree`` argument, unchanged by this function.
    """
    leaves, _ = tree_flatten(pytree)
    # A plain loop rather than a list comprehension: only the side effects
    # of ``fn_`` matter, so there is no point building and discarding a list.
    for leaf in leaves:
        fn_(leaf)
    return pytree
| |
| |
class PlaceHolder:
    """Stand-in leaf object whose repr is a single asterisk."""

    def __repr__(self):
        """Return ``'*'`` so the object prints as a bare marker."""
        marker = '*'
        return marker
| |
| |
def treespec_pprint(spec):
    """Return a string showing the shape of ``spec`` with leaves as ``*``.

    A fresh ``PlaceHolder`` is created for each of ``spec.num_leaves``
    leaves. They are unflattened back into the structure described by
    ``spec``, and the ``repr`` of the rebuilt object is returned.
    """
    placeholders = [PlaceHolder() for _leaf in range(spec.num_leaves)]
    return repr(tree_unflatten(placeholders, spec))