[Feature] Support fine-grained shared expert overlap (#5482)
Fine-grained control over shared expert overlap to prevent resource
contention.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -14,9 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Any, Callable, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
@@ -47,7 +50,20 @@ from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
|
||||
get_ascend_device_type, maybe_trans_nz,
|
||||
npu_stream_switch, shared_expert_dp_enabled,
|
||||
shared_experts_calculation_stream)
|
||||
shared_experts_calculation_stream, vllm_version_is)
|
||||
|
||||
@dataclass
|
||||
class FusedMoEResult:
|
||||
routed_out: torch.Tensor
|
||||
before_dispatch_evt: torch.npu.Event | None = None
|
||||
before_combine_evt: torch.npu.Event | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEEvents:
|
||||
before_routed_experts: torch.npu.Event
|
||||
before_dispatch: torch.npu.Event | None = field(default=None)
|
||||
before_combine: torch.npu.Event | None = field(default=None)
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
@@ -90,7 +106,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
@@ -137,7 +152,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
@@ -268,13 +282,13 @@ class AscendFusedMoE(FusedMoE):
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def forward_impl( # type: ignore[override]
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
|
||||
assert self.quant_method is not None
|
||||
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
|
||||
# Load balancing for token distribution among experts in dummy_run
|
||||
@@ -359,9 +373,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
shared_experts=None,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
@@ -380,7 +391,14 @@ class AscendFusedMoE(FusedMoE):
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata)
|
||||
|
||||
return routed_out
|
||||
if return_with_event:
|
||||
return FusedMoEResult(
|
||||
routed_out=routed_out,
|
||||
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
|
||||
before_combine_evt=fused_experts_results.before_combine_evt)
|
||||
else:
|
||||
# The vLLM FusedMoE forward_impl does not return events.
|
||||
return routed_out
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
@@ -407,6 +425,74 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
self._gate = gate
|
||||
|
||||
# Wrap the quant_method's process_weights_after_loading to validate that
|
||||
# splitting shared expert computation (gate_up projection + activation,
|
||||
# then down projection) yields identical results to integrated
|
||||
# computation after weight loading.
|
||||
original_process_weights = self.quant_method.process_weights_after_loading
|
||||
|
||||
@wraps(original_process_weights)
|
||||
def wrapped_process_weights(*args, **kwargs):
|
||||
result = original_process_weights(*args, **kwargs)
|
||||
self._validate_shared_expert_consistency()
|
||||
return result
|
||||
|
||||
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
||||
|
||||
def _shared_experts_part1(self, hidden_states: torch.Tensor):
|
||||
shared_gate_up, _ = self._shared_experts.gate_up_proj(
|
||||
hidden_states) # type: ignore
|
||||
return shared_gate_up
|
||||
|
||||
def _shared_experts_part2(self, hidden_states: torch.Tensor,
|
||||
shared_gate_up: torch.Tensor):
|
||||
shared_act = self._shared_experts.act_fn(
|
||||
shared_gate_up) # type: ignore
|
||||
shared_out, _ = self._shared_experts.down_proj(
|
||||
shared_act) # type: ignore
|
||||
|
||||
# Qwen3-Next specific gating mechanism
|
||||
if hasattr(self._shared_experts, "expert_gate") and \
|
||||
self._shared_experts.expert_gate is not None:
|
||||
if vllm_version_is('0.13.0'):
|
||||
# TODO(jianzs): remove this branch after vLLM new version is
|
||||
# released
|
||||
gate_out = self._shared_experts.expert_gate(hidden_states) # type: ignore
|
||||
else:
|
||||
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
|
||||
shared_out = F.sigmoid(gate_out) * shared_out
|
||||
return shared_out
|
||||
|
||||
def _validate_shared_expert_consistency(self):
|
||||
"""Validate that split shared expert computation matches integrated
|
||||
computation."""
|
||||
test_input = torch.rand(
|
||||
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
|
||||
) * 2 - 1 # Random input for testing, scoped to [-1, 1]
|
||||
|
||||
integrated_out = self._shared_experts(test_input)
|
||||
part1_out = self._shared_experts_part1(test_input)
|
||||
split_out = self._shared_experts_part2(test_input, part1_out)
|
||||
|
||||
if not torch.allclose(integrated_out, split_out):
|
||||
diff = (integrated_out - split_out).abs()
|
||||
logger.error(
|
||||
"SharedFusedMoE shared experts split computation does not "
|
||||
"match the integrated computation.")
|
||||
logger.error(f"Max absolute difference: {diff.max().item()}")
|
||||
logger.error("Integrated output - sum: %s, norm: %s",
|
||||
integrated_out.sum().item(),
|
||||
integrated_out.norm().item())
|
||||
logger.error("Split output - sum: %s, norm: %s",
|
||||
split_out.sum().item(),
|
||||
split_out.norm().item())
|
||||
raise ValueError(
|
||||
"SharedFusedMoE shared experts split computation does not "
|
||||
"match the integrated computation.")
|
||||
logger.info_once(
|
||||
"SharedFusedMoE shared experts split computation matches the "
|
||||
"integrated computation.")
|
||||
|
||||
@property
|
||||
def gate(self) -> Optional[torch.nn.Module]:
|
||||
return self._gate if self.use_overlapped else None
|
||||
@@ -434,44 +520,67 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
)
|
||||
return shared_out, fused_out
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
shared_out = None
|
||||
if not self.multistream_overlap_gate:
|
||||
# Make sure the shared experts stream begins after hidden_states are ready.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
shared_experts_calculation_stream(
|
||||
).wait_stream( # type: ignore
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(
|
||||
shared_experts_calculation_stream(),
|
||||
enabled=self.multistream_overlap_shared_expert):
|
||||
# Use a separate stream to run shared experts.
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
else:
|
||||
def _forward_shared_experts(self, hidden_states: torch.Tensor,
|
||||
fused_moe_evts: FusedMoEEvents):
|
||||
|
||||
def maybe_wait_event(evt: torch.npu.Event | None):
|
||||
if evt is not None:
|
||||
torch.npu.current_stream().wait_event(evt)
|
||||
|
||||
with npu_stream_switch(shared_experts_calculation_stream(),
|
||||
enabled=self.multistream_overlap_shared_expert):
|
||||
# Ensure the shared experts wait for hidden_states to be ready.
|
||||
torch.npu.current_stream().wait_event(
|
||||
fused_moe_evts.before_routed_experts)
|
||||
# Execute the gate projection and activation concurrently with the
|
||||
# dispatch communication.
|
||||
maybe_wait_event(fused_moe_evts.before_dispatch)
|
||||
part1_out = self._shared_experts_part1(hidden_states)
|
||||
# Execute the down projection concurrently with the combine
|
||||
# communication.
|
||||
maybe_wait_event(fused_moe_evts.before_combine)
|
||||
shared_out = self._shared_experts_part2(hidden_states, part1_out)
|
||||
|
||||
# Make sure the default stream waits for the shared experts stream to
|
||||
# finish.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
shared_experts_calculation_stream())
|
||||
|
||||
# NOTE: This is exactly the opposite of
|
||||
# `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
return shared_out
|
||||
|
||||
def forward_impl( # type: ignore[override]
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
if self.multistream_overlap_gate:
|
||||
set_flash_common3_context(shared_experts=self._shared_experts)
|
||||
|
||||
routed_out = AscendFusedMoE.forward_impl(
|
||||
before_routed_experts = torch.npu.current_stream().record_event()
|
||||
fused_moe_results = AscendFusedMoE.forward_impl(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
return_with_event=True,
|
||||
)
|
||||
routed_out = fused_moe_results.routed_out
|
||||
|
||||
if not self.multistream_overlap_gate:
|
||||
# Make sure the default stream waits for the shared experts stream to finish.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
shared_experts_calculation_stream())
|
||||
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
else:
|
||||
if self.multistream_overlap_gate:
|
||||
fc3_context = get_flash_common3_context()
|
||||
assert fc3_context is not None
|
||||
shared_out = fc3_context.shared_out
|
||||
else:
|
||||
shared_out = self._forward_shared_experts(
|
||||
hidden_states,
|
||||
FusedMoEEvents(
|
||||
before_routed_experts=before_routed_experts,
|
||||
before_dispatch=fused_moe_results.before_dispatch_evt,
|
||||
before_combine=fused_moe_results.before_combine_evt,
|
||||
))
|
||||
|
||||
return shared_out, routed_out
|
||||
|
||||
Reference in New Issue
Block a user