[Feature] Support fine-grained shared expert overlap (#5482)

Add fine-grained control over shared expert overlap to prevent resource
contention.

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2026-01-17 11:53:22 +08:00
committed by GitHub
parent 48e10de8c9
commit 22f253142a
9 changed files with 203 additions and 130 deletions

View File

@@ -151,8 +151,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
padded_hidden_states_shape = hidden_states.shape
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
@@ -162,6 +162,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
@@ -174,7 +175,9 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape
}
return hidden_states, router_logits, None, context_metadata
@@ -190,14 +193,25 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
assert context_metadata is not None
split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
assert context_metadata is not None
# Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[
"padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
split_hidden_states = torch.tensor_split(
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0)
hidden_states = gathered_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
@@ -249,7 +263,6 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
@@ -257,6 +270,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]
padded_hidden_states_shape = hidden_states.shape
if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape
target_pad_length = forward_context.padded_num_tokens
@@ -268,6 +282,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
# Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp:
@@ -280,7 +295,9 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape,
}
return hidden_states, router_logits, mc2_mask, context_metadata