performance optimization, usability optimization and API compatibility adjustments for deepseek with npu graph mode (#731)
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
1. Improve inference speed and usability for DeepSeek models with NPU
graph mode.
2. Modify some code to adapt to CANN 8.1.RC1.beta1.
3. Add a switch for NPU graph mode and its cache.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
This PR provides an experimental configuration to enable NPU graph mode
for DeepSeek models. Users can set
additional_config={'enable_graph_mode': True} to try this feature. Note
that this feature is currently supported only on the V0 engine.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
This patch was tested with the newest torch_npu 2.5.1
(https://pypi.org/project/torch-npu/#files) and CANN 8.1.RC1.beta1
toolkit&nnal&kernels
(https://www.hiascend.com/developer/download/community/result?module=cann)
released on 25/30 April.
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -590,14 +590,14 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
use_npu_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_torchair_graph:
|
||||
if self.num_prefills == 0 and use_npu_graph:
|
||||
num_seqs = len(seq_lens)
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
||||
self.block_tables.extend([[]] * graph_pad_size)
|
||||
@@ -915,7 +915,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
|
||||
k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
@@ -1123,9 +1123,17 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
if self.enable_graph_mode:
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
# shape of query for npu graph mode should be:
|
||||
# [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
block_size = kv_cache[0].shape[1]
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
@@ -1133,14 +1141,14 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=kv_cache[0].shape[1],
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
)
|
||||
attn_output = attn_output.view(num_tokens, -1,
|
||||
|
||||
@@ -30,6 +30,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
@@ -39,10 +40,13 @@ from vllm.distributed import (get_dp_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -55,15 +59,84 @@ from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # ruff: noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention,
|
||||
DeepseekV2MLP)
|
||||
DeepseekV2MLAAttention)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
|
||||
|
||||
class CustomDeepseekV2MLP(nn.Module):
    """Gate/up + down projection MLP with an optional W8A8 dynamic-quant path.

    The module owns a merged gate/up column-parallel projection and a
    row-parallel down projection. When the gate/up projection is backed by
    ``AscendW8A8DynamicLinearMethod``, ``forward`` bypasses the generic linear
    layers and calls the ``torch_npu`` quantized kernels directly; otherwise
    it runs ``gate_up_proj -> SiluAndMul -> down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the two projections and decide which forward path to use.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            reduce_results: Forwarded to ``RowParallelLinear`` for
                ``down_proj``.
            prefix: Name prefix used to derive the sub-layer prefixes.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in
        # dynamic quant, so the fused path is taken only when the wrapped
        # quant method is the W8A8 dynamic one.
        linear_method = self.gate_up_proj.quant_method
        if isinstance(linear_method, UnquantizedLinearMethod):
            self.is_dynamic_quant = False
        else:
            self.is_dynamic_quant = isinstance(linear_method.quant_method,
                                               AscendW8A8DynamicLinearMethod)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the projected result."""
        if not self.is_dynamic_quant:
            # Generic path: let the linear layers handle everything.
            gate_up, _ = self.gate_up_proj(x)
            activated = self.act_fn(gate_up)
            projected, _ = self.down_proj(activated)
            return projected

        # Dynamic-quant path: call the torch_npu kernels directly.
        gate_up_proj = self.gate_up_proj
        down_proj = self.down_proj

        quantized, token_scale = torch_npu.npu_dynamic_quant(x)
        # The matmul result is requested as int32 and handed to the next
        # kernel together with the fp32 weight scale and activation scale.
        gate_up_int32 = torch_npu.npu_quant_matmul(
            quantized,
            gate_up_proj.weight,
            gate_up_proj.weight_scale,
            output_dtype=torch.int32,
        )
        # NOTE(review): judging by its name this kernel performs dequant,
        # SwiGLU and re-quant in one call - confirm against torch_npu docs.
        activated, activated_scale = torch_npu.npu_dequant_swiglu_quant(
            x=gate_up_int32,
            weight_scale=gate_up_proj.weight_scale_fp32,
            activation_scale=token_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=None,
            activate_left=True,
            quant_mode=1,
        )
        output = torch_npu.npu_quant_matmul(
            activated,
            down_proj.weight,
            down_proj.weight_scale,
            pertoken_scale=activated_scale,
            output_dtype=torch.bfloat16,
        )
        # down_proj's own forward is bypassed here, so the tensor-parallel
        # all-reduce is issued explicitly when it was requested.
        if down_proj.reduce_results and down_proj.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output)
        return output
|
||||
|
||||
|
||||
class CustomDeepseekV2MoE(nn.Module):
|
||||
@@ -119,7 +192,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
self.shared_experts = CustomDeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
@@ -392,7 +465,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
self.mlp = CustomDeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
@@ -442,8 +515,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(self.mlp,
|
||||
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||
if isinstance(
|
||||
self.mlp,
|
||||
CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||
# Fix FP16 overflow
|
||||
# Scaling the DeepseekV2MLP output, it is the input of
|
||||
# input_layernorm of next decoder layer.
|
||||
@@ -582,4 +656,4 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
pass
|
||||
@@ -221,10 +221,16 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
|
||||
@@ -124,7 +124,10 @@ class NPUPlatform(Platform):
|
||||
enforce_eager = True
|
||||
logger.warning(
|
||||
"NPU compilation support pending. Will be available in future CANN and "
|
||||
"torch_npu releases. Using default: enforce_eager=True")
|
||||
"torch_npu releases. NPU graph mode is currently experimental and disabled "
|
||||
"by default. You can just adopt additional_config={'enable_graph_mode': True} "
|
||||
"to serve deepseek models with NPU graph mode on vllm-ascend with V0 engine. "
|
||||
)
|
||||
|
||||
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
|
||||
logger.info("Compilation disabled, using eager mode by default")
|
||||
@@ -150,6 +153,11 @@ class NPUPlatform(Platform):
|
||||
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
|
||||
)
|
||||
vllm_config.additional_config["enable_graph_mode"] = False
|
||||
if enable_graph_mode and envs.VLLM_USE_V1:
|
||||
logger.warning(
|
||||
"NPU graph mode is still experimental and not supported for V1 currently, "
|
||||
"it has been disabled automatically.")
|
||||
vllm_config.additional_config["enable_graph_mode"] = False
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
|
||||
@@ -62,38 +62,38 @@ def apply_mlp(x: torch.Tensor,
|
||||
h = x
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else \
|
||||
torch.float16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[h],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype)
|
||||
gate_up_out = gate_up_out_list[0]
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# swiglu
|
||||
swiglu_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
swiglu_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(swiglu_out)
|
||||
swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=gate_up_out,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# down_proj
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[swiglu_out],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype)
|
||||
return down_out_list[0]
|
||||
down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
return down_out
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
@@ -363,7 +363,10 @@ class AscendW8A8DynamicLinearMethod:
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
|
||||
|
||||
@@ -508,7 +511,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
|
||||
@@ -69,6 +69,8 @@ if TYPE_CHECKING:
|
||||
|
||||
TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
|
||||
ENCODER_NUM = 0
|
||||
# if True, allow tensor initialization and casting with internal format (e.g., NZ)
|
||||
torch.npu.config.allow_internal_format = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -864,10 +866,13 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
self.vllm_config.compilation_config.max_capture_size
|
||||
|
||||
self.enable_graph_mode = False
|
||||
self.use_cached_npu_graph = False
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
self.use_cached_npu_graph = additional_config.get(
|
||||
"use_cached_npu_graph", False)
|
||||
|
||||
self.has_inner_state = model_config.has_inner_state
|
||||
|
||||
@@ -981,12 +986,20 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
self.compile_model = torchair.inference.cache_compile(
|
||||
self.model.forward,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.compile_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=npu_backend)
|
||||
else:
|
||||
self.compile_model = torchair.inference.cache_compile(
|
||||
self.model.forward,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user