performance optimization, usability optimization and API compatibility adjustments for deepseek with npu graph mode (#731)

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
1. Improve inference speed and usability for deepsek models with NPU
graph mode.
2. Modify some codes to adapt to CANN 8.1.RC1.beta1.
3. Add a switch for NPU graph mode and its cache.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
This PR provides an experimental configuration to enable NPU graph mode
for Deepseek models. User can set
additional_config={'enable_graph_mode': True} to try this feature. Note
that this feature currently only supports for V0 engine.


### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
This patch was tested with the newest torch_npu 2.5.1
(https://pypi.org/project/torch-npu/#files) and CANN 8.1.RC1.beta1
toolkit&nnal&kernels
(https://www.hiascend.com/developer/download/community/result?module=cann)
released in 25/30 April.

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-05-01 13:51:42 +08:00
committed by GitHub
parent 399b03830d
commit 84e2ed898b
6 changed files with 163 additions and 51 deletions

View File

@@ -30,6 +30,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
@@ -39,10 +40,13 @@ from vllm.distributed import (get_dp_group, get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -55,15 +59,84 @@ from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # ruff: noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
DeepseekV2DecoderLayer,
DeepseekV2MLAAttention,
DeepseekV2MLP)
DeepseekV2MLAAttention)
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
class CustomDeepseekV2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
self.is_dynamic_quant = not isinstance(
self.gate_up_proj.quant_method,
UnquantizedLinearMethod) and isinstance(
self.gate_up_proj.quant_method.quant_method,
AscendW8A8DynamicLinearMethod)
def forward(self, x):
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class CustomDeepseekV2MoE(nn.Module):
@@ -119,7 +192,7 @@ class CustomDeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
self.shared_experts = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
@@ -392,7 +465,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepseekV2MLP(
self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
@@ -442,8 +515,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
if isinstance(
self.mlp,
CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
@@ -582,4 +656,4 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass
pass