[Feature] Support fine-grained shared expert overlap (#5482)
Fine-grained control over shared expert overlap to prevent resource
contention.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -151,8 +151,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
split_hidden_states = None
|
||||
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
@@ -162,6 +162,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
@@ -174,7 +175,9 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
context_metadata = {"split_hidden_states": split_hidden_states}
|
||||
context_metadata = {
|
||||
"padded_hidden_states_shape": padded_hidden_states_shape
|
||||
}
|
||||
|
||||
return hidden_states, router_logits, None, context_metadata
|
||||
|
||||
@@ -190,14 +193,25 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
assert context_metadata is not None
|
||||
|
||||
split_hidden_states = context_metadata["split_hidden_states"]
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
assert context_metadata is not None
|
||||
# Cannot reuse `split_hidden_states` from prepare phase as it
|
||||
# may share memory with original hidden_states. Since shared
|
||||
# experts may use the original tensor, reusing it would cause
|
||||
# in-place modification during all_gather, corrupting the data.
|
||||
padded_hidden_states_shape = context_metadata[
|
||||
"padded_hidden_states_shape"]
|
||||
gathered_hidden_states = torch.empty(
|
||||
padded_hidden_states_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
split_hidden_states = torch.tensor_split(
|
||||
gathered_hidden_states, self.tp_size, dim=0)
|
||||
dist.all_gather(list(split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(split_hidden_states, dim=0)
|
||||
hidden_states = gathered_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
@@ -249,7 +263,6 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
split_hidden_states = None
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
@@ -257,6 +270,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||
mc2_mask = split_mc2_mask[self.tp_rank]
|
||||
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
if not self.replace_allreduce:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
@@ -268,6 +282,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||
@@ -280,7 +295,9 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
context_metadata = {"split_hidden_states": split_hidden_states}
|
||||
context_metadata = {
|
||||
"padded_hidden_states_shape": padded_hidden_states_shape,
|
||||
}
|
||||
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
|
||||
Reference in New Issue
Block a user