[BugFix] Fix EPLB problems when using dynamic EPLB. (#3364)

### What this PR does / why we need it?
When using dynamic EPLB, the expert-weight transfer can be blocked by NZ-format tensors. We fix this by cloning the source tensor and the receive tensor before the transfer, as sketched below.
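A minimal sketch of the clone-before-transfer idea, assuming a hypothetical helper `transfer_expert_weight` and plain `torch.distributed` point-to-point calls; this is illustrative only, not the actual vllm-ascend EPLB code:

```python
import torch
import torch.distributed as dist

def transfer_expert_weight(weight: torch.Tensor, src_rank: int,
                           dst_rank: int, rank: int) -> None:
    # Hypothetical helper: illustrates cloning NZ-format tensors into
    # fresh buffers so the P2P send/recv does not stall on them.
    if rank == src_rank:
        # Clone the source weight so the send operates on a standalone
        # buffer rather than the original NZ-format tensor.
        send_buf = weight.clone()
        dist.send(send_buf, dst=dst_rank)
    elif rank == dst_rank:
        # Receive into a cloned buffer as well, then copy the result
        # back into the destination weight in place.
        recv_buf = weight.clone()
        dist.recv(recv_buf, src=src_rank)
        weight.copy_(recv_buf)
```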

### Does this PR introduce any user-facing change?

### How was this patch tested?
Tested with Qwen3_moe on A3.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: offline0806 <3337230449@qq.com>
Co-authored-by: offline0806 <3337230449@qq.com>
Commit 82b6c846ca (parent ca05f7d632)
offline893 committed on 2025-10-11 14:04:02 +08:00 (committed by GitHub)
8 changed files with 58 additions and 34 deletions

@@ -150,7 +150,8 @@ class MoECommMethod(ABC):
             with_quant=use_int8_w8a8
             or use_int4_w4a8,
             fusion=use_int8_w8a8,
-            need_trans=need_trans)
+            need_trans=need_trans,
+            dynamic_eplb=dynamic_eplb)
         final_hidden_states = self.token_dispatcher.token_combine(
             hidden_states=mlp_output)

@@ -63,7 +63,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                     dynamic_scale: torch.Tensor = None,
                     w1_scale_bias: torch.Tensor = None,
                     w2_scale_bias: torch.Tensor = None,
-                    fusion: bool = False) -> torch.Tensor:
+                    fusion: bool = False,
+                    dynamic_eplb: bool = False) -> torch.Tensor:
     if dynamic_scale is None:
         unquantized_hidden_states = hidden_states
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
@@ -79,7 +80,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
-        if fusion:
+        if fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
@@ -134,7 +135,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             # TODO w4a8 scene: dynamic acquisition of dtype in the future
             _output_dtype = torch.bfloat16
-        if fusion:
+        if fusion and not dynamic_eplb:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
@@ -229,7 +230,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                       topk_scales: Optional[torch.Tensor] = None,
                       with_quant: bool = False,
                       fusion: bool = False,
-                      need_trans: bool = True) -> torch.Tensor:
+                      need_trans: bool = True,
+                      dynamic_eplb: bool = False) -> torch.Tensor:
     if with_quant:
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
@@ -241,7 +243,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                                group_list_type=group_list_type,
                                w1_scale_bias=w1_scale_bias,
                                w2_scale_bias=w2_scale_bias,
-                               fusion=fusion)
+                               fusion=fusion,
+                               dynamic_eplb=dynamic_eplb)
     else:
         return unquant_apply_mlp(hidden_states=hidden_states,
                                  w1=w1,