88 changes: 65 additions & 23 deletions deepspeed/moe/sharded_moe.py
@@ -96,16 +96,19 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
class _AllToAll(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, async_op=False) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
dist.all_to_all_single(output, input, group=group)
return output
work = dist.all_to_all_single(output, input, group=group, async_op=async_op)
if async_op:
return output, work
else:
return output

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))
return (None, _AllToAll.apply(ctx.group, *grad_output), None)


# einsum rewrites are on par or more performant
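
A minimal usage sketch of the new async path may help here: with async_op=True the op returns the output buffer together with a communication handle, and the buffer is only safe to read after work.wait(). The group name ep_group below is illustrative, not part of the change.

# Minimal usage sketch for the async-capable _AllToAll (illustrative only).
import torch
from deepspeed.moe.sharded_moe import _AllToAll

def exchange(ep_group, tensor: torch.Tensor) -> torch.Tensor:
    # Synchronous path: behaves exactly as before, a single tensor comes back.
    sync_out = _AllToAll.apply(ep_group, tensor)

    # Async path: returns (output, work); the output buffer may only be read
    # after the communication handle has been waited on.
    async_out, work = _AllToAll.apply(ep_group, tensor, True)
    work.wait()

    assert torch.allclose(sync_out, async_out)
    return async_out
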
@@ -550,6 +553,7 @@ class MOELayer(Base):
expert (torch.nn.Module):
expert network
"""
d2d_stream = torch.cuda.Stream()

def __init__(self,
gate: Module,
@@ -572,6 +576,8 @@ def __init__(self,
self.wall_clock_breakdown = False

self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
self.enable_pipeline = True
self.shard_num = 4

if self.use_tutel:
logger.info('Using Tutel optimizations.')
@@ -586,8 +592,54 @@ def _set_ep_group(self, ep_group):
self.ep_group = ep_group
self.gate._set_ep_group(ep_group)

def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
# During multi-machine MoE training, all-to-all is the communication between machines and
# allgather is the communication within a machine. They use different communication links,
# so they can be executed in parallel.
# The input has shape (E, C, M). We shard it along the C dim and run all-to-all on each shard,
# so the allgather of the current shard overlaps with the all-to-all of the next shard, e.g.:
# A E I M
# A1 E1 I1 M1
# A2 E2 I2 M2
# A3 E3 I3 M3
# A4 E4 I4 M4
def pipeline_alltoall_with_allgather(self, input, shard_dim=1) -> Tensor:
if not self.enable_pipeline:
input = _AllToAll.apply(self.ep_group, input)
input = gather_tokens(input, dim=shard_dim)
return input

assert self.shard_num > 0, f"shard_num must be a positive integer, but got {self.shard_num}"
input_chunks = list(input.chunk(self.shard_num, dim=shard_dim))
world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
dims = list(input.size())
dims[shard_dim] = dims[shard_dim] * world_size
output = torch.empty(dims, device=input.device, dtype=input.dtype)
input_gather_dim_len = input.shape[shard_dim]
have_gather_len = 0
works = []
for i in range(len(input_chunks)):
input_chunks[i], work = _AllToAll.apply(self.ep_group, input_chunks[i], True)
works.append(work)

current_stream = torch.cuda.current_stream()
for i in range(len(input_chunks)):
works[i].wait()
# allgather and chunk along dim 0 so we avoid the unnecessary cat inside gather_tokens
gather_out = gather_tokens(input_chunks[i], dim=0)
gather_list = gather_out.chunk(world_size, dim=0)
dim_len = gather_list[0].shape[shard_dim]
MOELayer.d2d_stream.wait_stream(current_stream)

for j in range(len(gather_list)):
start = input_gather_dim_len * j + have_gather_len
with torch.cuda.stream(MOELayer.d2d_stream):
torch.narrow(output, shard_dim, start, dim_len).copy_(gather_list[j])
have_gather_len += dim_len

current_stream.wait_stream(MOELayer.d2d_stream)
return output
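
Reduced to its core, the pattern above is: shard the input, launch every all-to-all asynchronously, then allgather each shard as soon as its all-to-all completes, so the inter-node and intra-node collectives overlap. A standalone sketch of just that scheduling follows; it assumes ep_group and tp_group are already-initialized process groups and deliberately omits the offset bookkeeping that pipeline_alltoall_with_allgather performs to reproduce the gather_tokens output layout.

# Standalone sketch of the overlap pattern (assumptions: dist is initialized,
# ep_group/tp_group are valid groups; output layout handling is omitted).
import torch
import torch.distributed as dist

def overlapped_alltoall_allgather(x: torch.Tensor, ep_group, tp_group, shard_num=4, shard_dim=1):
    shards = [c.contiguous() for c in x.chunk(shard_num, dim=shard_dim)]

    # Launch every all-to-all up front; each returns a Work handle.
    pending = []
    for shard in shards:
        out = torch.empty_like(shard)
        work = dist.all_to_all_single(out, shard, group=ep_group, async_op=True)
        pending.append((out, work))

    # As each all-to-all finishes, allgather that shard across the tensor-parallel
    # group; later all-to-alls keep running on the inter-node link meanwhile.
    gathered = []
    tp_size = dist.get_world_size(tp_group)
    for out, work in pending:
        work.wait()
        bufs = [torch.empty_like(out) for _ in range(tp_size)]
        dist.all_gather(bufs, out, group=tp_group)
        gathered.append(bufs)
    return gathered  # per-shard, per-rank pieces; caller reassembles the layout
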

def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
if self.wall_clock_breakdown:
self.timers(MOE_TIMER).start()

@@ -611,9 +663,6 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)

if self.wall_clock_breakdown:
self.timers(FIRST_ALLTOALL_TIMER).start()

tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
if tensor_model_world_size > 1:
# If the non-expert is tensor-parallel,
@@ -628,18 +677,17 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
# an allgather to ensure correctness,
dispatched_input = drop_tokens(dispatched_input, dim=1)

dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

if self.wall_clock_breakdown:
self.timers(FIRST_ALLTOALL_TIMER).stop()
self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
self.timers(FIRST_ALLTOALL_TIMER).start()

if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
# if both expert and non-expert are tensor-parallel
# the dropped duplicate tokens need to be gathered on each
# tensor parallel rank again to ensure correctness
dispatched_input = gather_tokens(dispatched_input, dim=1)
dispatched_input = self.pipeline_alltoall_with_allgather(dispatched_input)
else:
dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

if self.wall_clock_breakdown:
self.timers(FIRST_ALLTOALL_TIMER).stop()
self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
expert_output = self.experts(dispatched_input)
@@ -654,18 +702,12 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
if self.wall_clock_breakdown:
self.timers(SECOND_ALLTOALL_TIMER).start()

expert_output = _AllToAll.apply(self.ep_group, expert_output)
expert_output = self.pipeline_alltoall_with_allgather(expert_output)

if self.wall_clock_breakdown:
self.timers(SECOND_ALLTOALL_TIMER).stop()
self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

if tensor_model_world_size > 1:
# the dropped duplicate tokens need to be gathered on each
# tensor parallel rank again for the tensor-parallel
# non-expert of the next layer.
expert_output = gather_tokens(expert_output, dim=1)

if self.use_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
else:
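The per-shard copies into output are issued on the class-level d2d_stream so they can overlap with the collectives on the default stream. Stripped of the MoE specifics, the handoff is a standard two-way wait_stream pattern; a minimal sketch, assuming a CUDA-capable device:

# Producer/consumer stream handoff, mirroring how MOELayer.d2d_stream is used.
import torch

side_stream = torch.cuda.Stream()

def copy_on_side_stream(dst: torch.Tensor, src: torch.Tensor) -> None:
    current = torch.cuda.current_stream()
    # The side stream must not start until `src` is ready on the current stream.
    side_stream.wait_stream(current)
    with torch.cuda.stream(side_stream):
        dst.copy_(src)
    # The current stream must not read `dst` until the copy has completed.
    current.wait_stream(side_stream)
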
89 changes: 89 additions & 0 deletions tests/unit/moe/test_pipeline.py
@@ -0,0 +1,89 @@
import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.sharded_moe import _AllToAll
from deepspeed.moe.mappings import gather_tokens
from deepspeed.moe.layer import MoE


class MPU():

def __init__(self, tp_world_size):
self.rank = deepspeed.comm.get_rank()
self.world_size = deepspeed.comm.get_world_size()
self.tp_world_size = tp_world_size

for i in range(0, self.world_size, tp_world_size):
ranks = range(i, i + tp_world_size)
group = deepspeed.comm.new_group(ranks)
if self.rank in ranks:
self.tp_group = group

for i in range(0, tp_world_size):
ranks = range(i, self.world_size, tp_world_size)
group = deepspeed.comm.new_group(ranks)
if self.rank in ranks:
self.dp_group = group

def get_model_parallel_rank(self):
return self.rank % self.tp_world_size

def get_model_parallel_world_size(self):
return self.tp_world_size

def get_data_parallel_rank(self):
return self.rank // self.tp_world_size

def get_data_parallel_world_size(self):
return self.world_size // self.tp_world_size

def get_data_parallel_group(self):
return self.dp_group

def get_model_parallel_group(self):
return self.tp_group


@pytest.mark.parametrize("shard_num", [6, 10])
@pytest.mark.parametrize("C, M, scale", [(92, 32, 1),(209, 128, 5)])
class TestPipelineCommunication(DistributedTest):
world_size = 8

def test(self, shard_num, C, M, scale):
tp_size = 2
world_size = deepspeed.comm.get_world_size()
E = world_size
ep_size = 4
config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
hidden_dim = M
device = get_accelerator().current_device_name()
tensor_parallel_expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 4 * hidden_dim // tp_size),
torch.nn.ReLU(),
torch.nn.Linear(4 * hidden_dim // tp_size, hidden_dim))

model = MoE(
hidden_size=hidden_dim,
expert=tensor_parallel_expert,
num_experts=world_size * scale,
ep_size=ep_size,
use_residual=True,
enable_expert_tensor_parallelism=True,
)
optimizer = torch.optim.AdamW(params=model.parameters())
model, _, _, _ = deepspeed.initialize(config=config_dict,
model=model,
optimizer=optimizer,
dist_init_required=False,
mpu=MPU(tp_size))
model.deepspeed_moe.shard_num = shard_num
input = torch.rand(E, C, M, device=device)

# pipeline alltoall with allgather
pipeline_output = model.deepspeed_moe.pipeline_alltoall_with_allgather(input)

# first alltoall, then allgather
alltoall_output = _AllToAll.apply(model.deepspeed_moe.ep_group, input)
gather_output = gather_tokens(alltoall_output, dim=1)
assert torch.allclose(pipeline_output, gather_output, atol=1e-07), f"pipeline_output {pipeline_output} is not equal to gather_output {gather_output}"
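
The test above checks equivalence across real process groups. The offset arithmetic on its own can also be sanity-checked on a single process by pretending every rank holds the same tensor, in which case the shard-by-shard narrow/copy must reproduce a plain concatenation. A self-contained sketch with illustrative sizes:

# Single-process check of the narrow/copy offsets used by
# pipeline_alltoall_with_allgather (illustrative sizes, no process groups).
import torch

E, C, M = 4, 8, 2
world_size, shard_num = 2, 4
x = torch.arange(E * C * M, dtype=torch.float32).reshape(E, C, M)

# Reference layout: gather_tokens concatenates each rank's full C along dim 1.
reference = torch.cat([x] * world_size, dim=1)

# Shard-by-shard reconstruction: rank j's slice of shard i lands at C*j + offset.
output = torch.empty(E, C * world_size, M)
have_gather_len = 0
for chunk in x.chunk(shard_num, dim=1):
    dim_len = chunk.shape[1]
    for rank in range(world_size):
        start = C * rank + have_gather_len
        output.narrow(1, start, dim_len).copy_(chunk)
    have_gather_len += dim_len

assert torch.equal(output, reference)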