[EPLB][Ops] Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB (#4216)
### What this PR does / why we need it? Integrate grouped_matmul_swiglu_quant_weight_nz_tensor_list into dynamic EPLB to support list-type parameters. This PR also modifies the model-loading logic in the dynamic EPLB scenario. The operator is based on this PR: https://github.com/vllm-project/vllm-ascend/pull/3804 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ``` vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \ --max_num_seqs 8 \ --max-model-len 8192 \ --max-num-batched-tokens 16384 \ --tensor-parallel-size 8 \ --data-parallel-size 2 \ --enable-expert-parallel \ --served-model-name ds_r1 \ --enable-auto-tool-choice \ --tool-call-parser hermes \ --no-enable-prefix-caching \ --port 8999 \ --quantization "ascend" \ --gpu-memory-utilization 0.85 \ --trust-remote-code \ --compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \ --additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}' ``` input&output: 2k 2k This PR: <img width="1318" height="695" alt="fusion" src="https://github.com/user-attachments/assets/f8657813-0c02-42f4-8396-d99e730f48cd" /> Baseline: <img width="1323" height="690" alt="baseline" src="https://github.com/user-attachments/assets/e1323a78-af26-4523-820c-e20e5642a38e" /> - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: 白永斌 <baiyongbin3@h-partners.com> Signed-off-by: 欧派果奶我还要 <845473182@qq.com> Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
This commit is contained in:
@@ -379,10 +379,10 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1=[layer.w13_weight],
|
||||
w2=[layer.w2_weight],
|
||||
w1_scale=[layer.w13_weight_scale],
|
||||
w2_scale=[layer.w2_weight_scale],
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
|
||||
@@ -236,13 +236,24 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale_fp32,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
@@ -274,3 +285,25 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
if self.dynamic_eplb:
|
||||
layer.w13_weight_list = [
|
||||
weight.clone()
|
||||
for weight in layer.w13_weight.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_list = [
|
||||
weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
|
||||
]
|
||||
layer.w13_weight_scale_fp32_list = [
|
||||
weight.clone()
|
||||
for weight in layer.w13_weight_scale.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_scale_list = [
|
||||
weight.clone()
|
||||
for weight in layer.w2_weight_scale.data.unbind(dim=0)
|
||||
]
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user