# Owner(s): ["oncall: distributed"]
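
# Tests that a custom Sharder can be plugged into shard_module() through a
# ShardingPlan: the sharder below splits an embedding bag collection into
# row-wise and column-wise sharded weights, and the tests compare the sharded
# module against its local counterpart.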

import sys
import copy

import torch
import torch.nn as nn
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of embedding bags, keyed by name in an nn.ModuleDict
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag,
                num_dims,
                mode="sum",
            )

    def forward(self, inputs):
        # run every bag on the same input and concatenate the pooled outputs
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of the embedding bag collection (EBC)
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the specs: bags before split_idx are
        # sharded row-wise, the remaining bags column-wise
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            if i < self.split_idx:
                shard_module(ebc, plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": row_spec}))
            else:
                shard_module(ebc, plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": col_spec}))

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    # shard() is called on the submodule registered under this sharder in the
    # ShardingPlan and returns its sharded replacement
    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError("The custom sharder only supports CustomEmbeddingBagCollection")

        return CustomShardedEBC(ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec))


class TestCustomSharder(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            })

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has been sharded as planned
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(emb_bags["embedding_bag_0"].weight.sharding_spec(), custom_sharder.rowwise_spec)
        self.assertEqual(emb_bags["embedding_bag_9"].weight.sharding_spec(), custom_sharder.colwise_spec)

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        # an empty module path is not allowed for a custom sharder
        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            })

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan: a parameter spec nested under a
        # path that is also assigned to a custom sharder
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            })

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)
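

if __name__ == "__main__":
    # Minimal entry point so the file can be run directly, following the usual
    # PyTorch test convention; assumes run_tests from
    # torch.testing._internal.common_utils, as used by other tests in this suite.
    from torch.testing._internal.common_utils import run_tests

    run_tests()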