From d6bfae8eeebedf677b643b712d367a3a69c9cce4 Mon Sep 17 00:00:00 2001
From: sunbaosong <13793883820@163.com>
Date: Tue, 6 May 2025 10:12:07 +0800
Subject: [PATCH] support 32K model len on deepseek r1 W8A8 (#728)

### What this PR does / why we need it?
Optimize NPU memory usage. See
https://github.com/vllm-project/vllm-ascend/issues/723

With vLLM v0.8.4.rc2, DeepSeek R1 W8A8 can only serve a model length of
16K; attempting to run with a model length of 32K triggers an
out-of-memory (OOM) error. This patch replaces two out-of-place tensor
operations in the fused-experts path of the W8A8 dynamic quantization
code with in-place equivalents, so the expert output buffer is reused
instead of allocating full-size intermediate tensors.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

Signed-off-by: sunbaosong <13793883820@163.com>
---
 vllm_ascend/quantization/w8a8_dynamic.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index a558c89..bcd313d 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -276,8 +276,7 @@ def fused_experts(hidden_states: torch.Tensor,
                               group_list_type=group_list_type)
 
     if expert_map is not None:
-        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
-
+        down_out_list.mul_(sorted_weights.unsqueeze(1))
         final_hidden_states = torch.zeros(*original_shape,
                                           device=hidden_states.device,
                                           dtype=dtype)
@@ -286,10 +285,8 @@ def fused_experts(hidden_states: torch.Tensor,
         valid_token_mask = torch.arange(
             0, sorted_token_indices.shape[0],
             device=device).unsqueeze(1) < num_valid_tokens
-        valid_output = torch.where(
-            valid_token_mask, weighted_down_out,
-            torch.zeros_like(weighted_down_out)).to(dtype)
-        final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
+        down_out_list.mul_(valid_token_mask)
+        final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
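
Note for reviewers: below is a minimal, self-contained sketch of the
before/after pattern in this patch. Tensor names mirror the diff, but the
shapes, values, and standalone setup are illustrative assumptions, not the
real model dimensions or call sites.

```python
import torch

# Hypothetical shapes for illustration only (not the real model dims):
# 6 rows of expert MLP output with hidden size 8; the first 4 sorted
# rows correspond to valid tokens, the rest are padding.
num_sorted_tokens, hidden_size, num_valid_tokens = 6, 8, 4
down_out_list = torch.randn(num_sorted_tokens, hidden_size)
sorted_weights = torch.rand(num_sorted_tokens)
sorted_token_indices = torch.tensor([0, 1, 1, 2, 3, 3])
final_hidden_states = torch.zeros(4, hidden_size)

# Before this patch, each step materialized a fresh
# (num_sorted_tokens, hidden_size) tensor:
#   weighted = down_out_list * sorted_weights.unsqueeze(1)
#   valid = torch.where(mask, weighted, torch.zeros_like(weighted))
#
# After the patch, both steps mutate down_out_list in place, so no
# extra full-size intermediates are kept alive:
down_out_list.mul_(sorted_weights.unsqueeze(1))  # scale rows by routing weight
valid_token_mask = (torch.arange(num_sorted_tokens).unsqueeze(1)
                    < num_valid_tokens)
down_out_list.mul_(valid_token_mask)             # zero out the padding rows
final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
```

In this branch, down_out_list is not read again after the scatter, so
mutating it in place is safe. Dropping the torch.where path also removes
the torch.zeros_like allocation and its output tensor, which appears to be
the saving that lets a 32K model length fit in NPU memory.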