support deepseek quant & mix-parallel with graphmode (#585)

### What this PR does / why we need it?
1. Support DeepSeek with W8A8 quantization;
2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP);
3. Support DeepSeek with graph mode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -11,8 +11,6 @@
import gc import gc
import os import os
VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
def main(): def main():
dp_rank = int(os.environ['RANK']) dp_rank = int(os.environ['RANK'])
@@ -20,8 +18,8 @@ def main():
dp_size = int(os.environ['WORLD_SIZE']) dp_size = int(os.environ['WORLD_SIZE'])
master_addr = os.environ['MASTER_ADDR'] master_addr = os.environ['MASTER_ADDR']
master_port = os.environ['MASTER_PORT'] master_port = os.environ['MASTER_PORT']
tp_size = 4 tp_size = 1
etp_size = 2 etp_size = 1
os.environ["VLLM_DP_RANK"] = str(dp_rank) os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size) os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -58,15 +56,15 @@ def main():
max_tokens=4, max_tokens=4,
min_tokens=4) min_tokens=4)
# Create an LLM. # Create an LLM.
llm = LLM( llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
model="deepseek-ai/DeepSeek-V2-Lite-Chat", tensor_parallel_size=tp_size,
tensor_parallel_size=tp_size, trust_remote_code=True,
trust_remote_code=True, max_model_len=4096,
expert_tensor_parallel_size=etp_size, max_num_seqs=num_seqs,
max_model_len=4096, additional_config={
max_num_seqs=num_seqs, 'expert_tensor_parallel_size': etp_size,
compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0, 'enable_graph_mode': False,
) })
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
for output in outputs: for output in outputs:

View File

@@ -6,15 +6,13 @@ export HCCL_SOCKET_IFNAME=${ifname}
# dp_size = node_size * dp_per_node # dp_size = node_size * dp_per_node
node_size=1 node_size=1
node_rank=0 node_rank=0
dp_per_node=2 dp_per_node=4
master_addr=127.0.0.1 master_addr=127.0.0.1
master_port=12345 master_port=12345
rm -rf ./.torchair_cache/ rm -rf ./.torchair_cache/
rm -rf ./dynamo_* rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/* rm -rf /root/ascend/log/debug/plog/*
export VLLM_ENABLE_GRAPH_MODE=0
export VLLM_ENABLE_MC2=0
torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \ torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
--node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \ --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \

View File

@@ -27,6 +27,7 @@ try:
except ImportError: except ImportError:
print("Failed to import torch_npu.") print("Failed to import torch_npu.")
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
@@ -36,9 +37,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.config import get_current_vllm_config
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
from vllm_ascend.worker.model_runner import ( from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
@@ -913,6 +914,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.w_kc = None self.w_kc = None
self.w_vc = None self.w_vc = None
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def exec_kv( def exec_kv(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -1084,7 +1091,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.num_heads, -1) self.num_heads, -1)
# TODO: Replace the env with more flexible expressions # TODO: Replace the env with more flexible expressions
if VLLM_ENABLE_GRAPH_MODE == '1': if self.enable_graph_mode:
if len(kv_cache) > 0 and kv_cache[0].numel( if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.num_prefills > 0: ) > 0 and attn_metadata.num_prefills > 0:
slots = attn_metadata.slot_mapping slots = attn_metadata.slot_mapping
@@ -1141,7 +1148,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
) )
elif attn_metadata.decode_metadata: elif attn_metadata.decode_metadata:
assert kv_cache is not None assert kv_cache is not None
if VLLM_ENABLE_GRAPH_MODE == '1': if self.enable_graph_mode:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)

View File

@@ -26,13 +26,13 @@
# """Inference-only DeepseekV2/DeepseekV3 model.""" # """Inference-only DeepseekV2/DeepseekV3 model."""
import os import os
from typing import Any, Dict, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig, from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config) get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group, from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
class CustomDeepseekV2MoE(nn.Module): class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs batch_size = vllm_config.scheduler_config.max_num_seqs
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1 self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros( self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
if VLLM_ENABLE_GRAPH_MODE == "1": self.enable_graph_mode = False
self.forward = self.forward_torchair additional_config = get_current_vllm_config().additional_config
else: if additional_config:
self.forward = self.forward_eager # type: ignore self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def forward_torchair(self, def forward(
positions: torch.Tensor, self,
hidden_states: torch.Tensor, positions: torch.Tensor,
kv_cache: torch.Tensor = None, hidden_states: torch.Tensor,
attn_metadata=None): kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0] ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq) hidden_states_or_q_c = self.q_a_layernorm(ckq)
else: else:
hidden_states_or_q_c = hidden_states hidden_states_or_q_c = hidden_states
return self.mla_attn(hidden_states_or_q_c, hidden_states, None, if self.enable_graph_mode:
kv_cache, attn_metadata) return self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
def forward_eager(self, positions: torch.Tensor, hidden_states, None, kv_cache,
hidden_states: torch.Tensor): attn_metadata)
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else: else:
hidden_states_or_q_c = hidden_states kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) return self.mla_attn(hidden_states_or_q_c,
return self.mla_attn(hidden_states_or_q_c, kv_c_normed,
kv_c_normed, k_pe,
k_pe, output_shape=hidden_states.shape)
output_shape=hidden_states.shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    kv_cache: Optional[torch.Tensor] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one decoder layer: self-attention followed by the MLP block.

    Args:
        positions: Forwarded unchanged to ``self.self_attn``.
        hidden_states: Input activations of this layer.
        residual: Running residual from the previous layer, or ``None``
            for the first call, in which case ``hidden_states`` itself
            becomes the residual.
        kv_cache: Forwarded unchanged to ``self.self_attn``.
        attn_metadata: Forwarded unchanged to ``self.self_attn``.

    Returns:
        The pair ``(hidden_states, residual)`` to feed into the next
        layer.
    """
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
    )

    if hidden_states.dtype == torch.float16:
        # Fix FP16 overflow.
        # Both hidden_states and residual are scaled before the RMSNorm;
        # the RMSNorm result is not affected by the scale.
        hidden_states *= 1. / self.routed_scaling_factor
        if self.layer_idx == 0:
            # The residual is shared by all layers, so it is only scaled
            # once, on the first layer.
            residual *= 1. / self.routed_scaling_factor

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)
    hidden_states = self.mlp(hidden_states)

    if isinstance(self.mlp,
                  DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        # Fix FP16 overflow.
        # Scale the DeepseekV2MLP output, which is the input of the
        # input_layernorm of the next decoder layer. The scaling of a
        # DeepseekV2MoE output is done inside the MoE forward instead.
        hidden_states *= 1. / self.routed_scaling_factor

    return hidden_states, residual
class CustomDeepseekV2Model(nn.Module): class CustomDeepseekV2Model(nn.Module):
@@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors], kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for i in range(self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual) layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: Optional[List[torch.Tensor]] = None,
    attn_metadata: Optional[AttentionMetadata] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate the forward pass to ``self.model``.

    All arguments are handed to ``self.model`` positionally, in the
    order they are declared here, and its result is returned unchanged.
    """
    return self.model(
        input_ids,
        positions,
        kv_caches,
        attn_metadata,
        intermediate_tensors,
        inputs_embeds,
    )
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass pass

View File

@@ -330,17 +330,16 @@ def native_grouped_topk(
def select_experts( def select_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Select top-k experts based on router logits. Select top-k experts based on router logits.
@@ -364,7 +363,6 @@ def select_experts(
Raises: Raises:
ValueError: If an unsupported scoring function is provided. ValueError: If an unsupported scoring function is provided.
""" """
if custom_routing_function is not None: if custom_routing_function is not None:
raise NotImplementedError( raise NotImplementedError(
"Custom routing function is not supported now") "Custom routing function is not supported now")
@@ -466,21 +464,36 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
is_prefill=False, is_prefill=False,
**kwargs, **kwargs,
): ):
# set prefill as false always, should fix this # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids = select_experts( if global_num_experts == 256:
hidden_states=x, topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits=router_logits, router_logits,
top_k=top_k, k=top_k, # topk当前写8
use_grouped_topk=use_grouped_topk, bias=e_score_correction_bias,
renormalize=renormalize, k_group=topk_group, # fix: 4
topk_group=topk_group, group_count=num_expert_group, # fix 8
num_expert_group=num_expert_group, group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
custom_routing_function=custom_routing_function, renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
scoring_func=scoring_func, norm_type=1, # 0: softmax; 1: sigmoid(fix)
e_score_correction_bias=e_score_correction_bias, # out_flag=False, # todo new api; 第三个输出是否输出
is_prefill=is_prefill) # y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill: if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
return fused_experts_with_mc2( return fused_experts_with_mc2(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -611,10 +624,11 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k real_top_k = self.top_k
if self.dp_size > 1: if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill: ) == 1 and not is_prefill:
... ...
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore elif int(os.environ.get("USING_LCCL_COM",
'0')) == 1: # type: ignore
hidden_states = get_dp_group().all_gather( hidden_states = get_dp_group().all_gather(
hidden_states, 0, False) hidden_states, 0, False)
router_logits = get_dp_group().all_gather( router_logits = get_dp_group().all_gather(
@@ -631,7 +645,7 @@ class AscendFusedMoE(FusedMoE):
top_k=real_top_k, top_k=real_top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.num_experts, global_num_experts=self.global_num_experts,
expert_map=self.expert_map, expert_map=self.expert_map,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
@@ -641,7 +655,7 @@ class AscendFusedMoE(FusedMoE):
is_prefill=is_prefill) is_prefill=is_prefill)
if self.dp_size > 1: if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
) == 1 and not is_prefill: ) == 1 and not is_prefill:
... ...
else: else:

View File

@@ -24,6 +24,7 @@ import torch_npu # noqa: F401
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import logger from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum from vllm.platforms import Platform, PlatformEnum
from vllm.utils import supports_dynamo
CUSTOM_OP_ENABLED = False CUSTOM_OP_ENABLED = False
try: try:
@@ -119,6 +120,15 @@ class NPUPlatform(Platform):
compilation_config.level) compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION compilation_config.level = CompilationLevel.NO_COMPILATION
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(
"enable_graph_mode", False)
if enable_graph_mode and not supports_dynamo():
logger.warning(
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
)
vllm_config.additional_config["enable_graph_mode"] = False
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
if parallel_config and parallel_config.worker_cls == "auto": if parallel_config and parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:

View File

@@ -310,21 +310,22 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
return self.quant_method.apply(layer, x, router_logits, top_k, return self.quant_method.apply(layer, x, router_logits, top_k,
renormalize, use_grouped_topk, renormalize, use_grouped_topk,
topk_group, num_expert_group,
global_num_experts, expert_map, global_num_experts, expert_map,
topk_group, num_expert_group,
custom_routing_function, scoring_func, custom_routing_function, scoring_func,
e_score_correction_bias) e_score_correction_bias, is_prefill)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"): if hasattr(self.quant_method, "process_weights_after_loading"):

View File

@@ -23,10 +23,8 @@ import torch_npu
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
input_offset: torch.Tensor): input_offset: torch.Tensor):
out = torch.empty_like(in_tensor, dtype=torch.int8) return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset, torch.qint8, -1, True)
out)
return out
class AscendW8A8LinearMethod: class AscendW8A8LinearMethod:
@@ -88,7 +86,11 @@ class AscendW8A8LinearMethod:
) -> torch.Tensor: ) -> torch.Tensor:
original_dtype = x.dtype original_dtype = x.dtype
if original_dtype != torch.int8: if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset) x = quant_per_tensor(
x,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None quant_bias = layer.quant_bias if tp_rank == 0 else None
return torch_npu.npu_quant_matmul( return torch_npu.npu_quant_matmul(
x, x,
@@ -99,6 +101,13 @@ class AscendW8A8LinearMethod:
) )
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False)
if self.transpose_weight: if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_scale.data = torch.flatten(layer.weight_scale.data)

View File

@@ -15,14 +15,183 @@
# limitations under the License. # limitations under the License.
# #
import os
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
import torch import torch
import torch_npu import torch_npu
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.ops.fused_moe import select_experts
def apply_mlp(x: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: Optional[torch.Tensor] = None,
              group_list_type: int = 1) -> torch.Tensor:
    """Apply the expert MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        x: input hidden states with shape (num_tokens, hidden_size).
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape
            (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: per-expert token grouping with shape (num_experts),
            interpreted according to ``group_list_type``.
        dynamic_scale: per-token scale of an already quantized ``x``.
            When ``None``, ``x`` is quantized here with
            ``torch_npu.npu_dynamic_quant``.
        group_list_type: forwarded to ``torch_npu.npu_grouped_matmul``.
            NOTE(review): callers in this file pass plain per-expert
            token counts with 1 and the output of
            ``npu_moe_compute_expert_tokens`` with 0 - confirm the exact
            semantics against the torch_npu documentation.

    The weights are expected in their transposed layout, i.e.
        w1: (num_experts, intermediate_size * 2, hidden_size) ->
            (num_experts, hidden_size, intermediate_size * 2)
        w2: (num_experts, hidden_size, intermediate_size) ->
            (num_experts, intermediate_size, hidden_size)

    Returns:
        hidden_states: output hidden states after MLP.
    """
    if dynamic_scale is None:
        h, pertoken_scale = torch_npu.npu_dynamic_quant(x)
    else:
        # The caller already quantized x and supplies its per-token scale.
        h = x
        pertoken_scale = dynamic_scale

    # Follow the dtype of the weight scale; anything that is not bfloat16
    # falls back to float16.
    output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else \
        torch.float16

    # gmm1: gate_up_proj
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[h],
        weight=[w1],
        scale=[w1_scale],
        per_token_scale=[pertoken_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype)
    gate_up_out = gate_up_out_list[0]

    # swiglu, then re-quantize the activation for the second matmul
    swiglu_out = torch_npu.npu_swiglu(gate_up_out)
    swiglu_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(swiglu_out)

    # down_proj
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[swiglu_out],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[swiglu_out_scale],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype)
    return down_out_list[0]
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: Optional[torch.Tensor] = None,
    moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
    """Run the quantized experts using the MC2 dispatch/combine operators.

    The tokens are distributed with
    ``torch_npu.npu_moe_distribute_dispatch``, pushed through
    ``apply_mlp`` and gathered back with
    ``torch_npu.npu_moe_distribute_combine``.

    Args:
        hidden_states: tokens to route.
        w1: expert weights of the gate_up projection.
        w2: expert weights of the down projection.
        w1_scale: quantization scale of ``w1``.
        w2_scale: quantization scale of ``w2``.
        topk_weights: routing weights; cast to float32 for the combine.
        topk_ids: selected expert ids for every token.
        top_k: accepted for signature parity; not read by this function.
        expert_map: its length is used as the expert count
            (``moe_expert_num``). Required despite the ``None`` default.
        moe_all_to_all_group_name: communicator name passed as both
            ``group_ep`` and ``group_tp``.

    Returns:
        The result of ``torch_npu.npu_moe_distribute_combine``.

    Raises:
        TypeError: if ``expert_map`` is ``None``.
    """
    if expert_map is None:
        # len(None) would raise an opaque TypeError below; fail early with
        # the same exception type and an actionable message.
        raise TypeError(
            "fused_experts_with_mc2 requires `expert_map`, got None")

    global_bs = 0
    moe_expert_num = len(expert_map)
    quant_mode = 2

    rank = torch.distributed.get_rank()
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

    # The TP size is derived from the global world size and the EP size.
    world_size = torch.distributed.get_world_size()
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    # Stage 1: dispatch the tokens to the ranks that own their experts.
    dispatch_kwargs = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # The EP communicator name is reused for the TP group.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    output = torch_npu.npu_moe_distribute_dispatch(**dispatch_kwargs)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = \
        output[0:5]

    if quant_mode == 0:
        # Without quantization in the dispatch there is no usable scale, so
        # apply_mlp has to quantize the activations itself.
        dynamic_scale = None

    # Stage 2: expert computation.
    down_out_list = apply_mlp(expand_x,
                              w1,
                              w1_scale,
                              w2,
                              w2_scale,
                              expert_token_nums,
                              dynamic_scale=dynamic_scale)

    # Stage 3: combine the expert outputs back onto the source ranks.
    tp_recv_counts = torch.empty(1,
                                 dtype=torch.int32,
                                 device=hidden_states.device)
    combine_kwargs = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        # The EP communicator name is reused for the TP group.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    hidden_states = torch_npu.npu_moe_distribute_combine(**combine_kwargs)

    return hidden_states
def fused_experts(hidden_states: torch.Tensor, def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
@@ -75,11 +244,10 @@ def fused_experts(hidden_states: torch.Tensor,
dtype=torch.int64) dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64) ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
token_counts = token_counts[:num_experts] expert_tokens = token_counts[:num_experts]
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
# Rearrange hidden_states # Rearrange hidden_states
sorted_hidden_states = hidden_states[sorted_token_indices] sorted_hidden_states = hidden_states[sorted_token_indices]
group_list_type = 1
else: else:
row_idx_len = num_tokens * top_k row_idx_len = num_tokens * top_k
row_idx = torch.arange(0, row_idx = torch.arange(0,
@@ -97,46 +265,15 @@ def fused_experts(hidden_states: torch.Tensor,
expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts) expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64) expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
quant_x, x_dynamic_scale = torch_npu.npu_dynamic_quant( down_out_list = apply_mlp(sorted_hidden_states,
sorted_hidden_states) w1,
del sorted_hidden_states w1_scale,
output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else torch.float16 w2,
w2_scale,
gate_up_out_list = torch_npu.npu_grouped_matmul( expert_tokens,
x=[quant_x], group_list_type=group_list_type)
weight=[w1],
scale=[w1_scale],
per_token_scale=[x_dynamic_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=output_dtype)
del quant_x
gate_up_out_list = gate_up_out_list[0] if len(
gate_up_out_list) == 1 else torch.cat(gate_up_out_list, dim=0)
gate_up_out_list = torch_npu.npu_swiglu(gate_up_out_list)
quant_gate_up_out_list, gate_up_out_dynamic_scale = torch_npu.npu_dynamic_quant(
gate_up_out_list)
del gate_up_out_list
down_out_list = torch_npu.npu_grouped_matmul(
x=[quant_gate_up_out_list],
weight=[w2],
scale=[w2_scale],
per_token_scale=[gate_up_out_dynamic_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=output_dtype)
del quant_gate_up_out_list
down_out_list = down_out_list[0] if len(down_out_list) == 1 else torch.cat(
down_out_list, dim=0)
if expert_map is not None: if expert_map is not None:
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
@@ -144,12 +281,18 @@ def fused_experts(hidden_states: torch.Tensor,
final_hidden_states = torch.zeros(*original_shape, final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device, device=hidden_states.device,
dtype=dtype) dtype=dtype)
final_hidden_states.index_add_(0, sorted_token_indices,
weighted_down_out) num_valid_tokens = mask.sum()
# TODO: This should not happen! Look into it! valid_token_mask = torch.arange(
# fill nan with 0.0 0, sorted_token_indices.shape[0],
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0 device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
else: else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing( final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list, down_out_list,
skip1=None, skip1=None,
@@ -157,7 +300,8 @@ def fused_experts(hidden_states: torch.Tensor,
bias=None, bias=None,
scales=topk_weights, scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx, expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids) export_for_source_row=topk_ids,
)
del down_out_list del down_out_list
if len(original_shape) == 3: if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape) final_hidden_states = final_hidden_states.view(original_shape)
@@ -230,6 +374,18 @@ class AscendW8A8DynamicFusedMoEMethod:
def __init__(self): def __init__(self):
self.transpose_weight = True self.transpose_weight = True
ep_group = get_ep_group()
try:
device_group = ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
@staticmethod @staticmethod
def get_weight(num_experts: int, intermediate_size_per_partition: int, def get_weight(num_experts: int, intermediate_size_per_partition: int,
hidden_sizes: int, hidden_sizes: int,
@@ -272,48 +428,78 @@ class AscendW8A8DynamicFusedMoEMethod:
dtype=params_dtype) dtype=params_dtype)
return param_dict return param_dict
@staticmethod
def apply( def apply(
self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[ assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch" 1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids = select_experts( # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
hidden_states=x, if global_num_experts == 256:
router_logits=router_logits, topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
top_k=top_k, router_logits,
use_grouped_topk=use_grouped_topk, k=top_k, # topk当前写8
renormalize=renormalize, bias=e_score_correction_bias,
topk_group=topk_group, k_group=topk_group, # fix: 4
num_expert_group=num_expert_group, group_count=num_expert_group, # fix 8
custom_routing_function=custom_routing_function, group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
scoring_func=scoring_func, renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
e_score_correction_bias=e_score_correction_bias, norm_type=1, # 0: softmax; 1: sigmoid(fix)
) # out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
return fused_experts(hidden_states=x, if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
w1=layer.w13_weight, return fused_experts_with_mc2(
w1_scale=layer.w13_weight_scale, hidden_states=x,
w2=layer.w2_weight, w1=layer.w13_weight,
w2_scale=layer.w2_weight_scale, w2=layer.w2_weight,
topk_weights=topk_weights, w1_scale=layer.w13_weight_scale,
topk_ids=topk_ids, w2_scale=layer.w2_weight_scale,
top_k=top_k, topk_weights=topk_weights,
expert_map=expert_map) topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
else:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.transpose_weight: if self.transpose_weight:

View File

@@ -16,8 +16,6 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py # Adapted from vllm-project/vllm/vllm/worker/worker.py
# #
import os
import torch import torch
import torch_npu # noqa: F401 import torch_npu # noqa: F401
from packaging.version import Version from packaging.version import Version
@@ -25,8 +23,6 @@ from vllm.logger import logger
import vllm_ascend.envs as envs import vllm_ascend.envs as envs
VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')
def try_register_lib(lib_name: str, lib_info: str = ""): def try_register_lib(lib_name: str, lib_info: str = ""):
import importlib import importlib

View File

@@ -17,53 +17,66 @@
# limitations under the License. # limitations under the License.
# #
from typing import List, Tuple from typing import Any, List
import torch import torch
from vllm.config import get_current_vllm_config
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
def allocate_kv_cache( def allocate_kv_cache(
self, self,
num_blocks: int, num_blocks: int,
device: str, device: str,
) -> List[Tuple]: ) -> List[Any]:
"""Allocates KV cache on the specified device.""" """Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape( kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size) num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[Tuple] = [] kv_cache: List[Any] = []
# Align entries so they are 256 byte aligned for better performance additional_config = get_current_vllm_config().additional_config
# Primarily targets MLA as this typically only ends up having entries if additional_config and additional_config.get("enable_graph_mode", False):
# be 128 byte aligned. # Align entries so they are 256 byte aligned for better performance
alloc_shape = kv_cache_shape # Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.
alloc_shape = kv_cache_shape
for _ in range(self.num_attention_layers): for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that # null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out. # block to be zeroed-out.
# We zero-out everything for simplicity. # We zero-out everything for simplicity.
layer_kv_cache_nope = torch.zeros( layer_kv_cache_nope = torch.zeros(
alloc_shape[:-1] + alloc_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ), (self.model_config.hf_text_config.kv_lora_rank, ),
dtype=self.dtype, dtype=self.dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
device=device) device=device)
layer_kv_cache_pe = torch.zeros( layer_kv_cache_pe = torch.zeros(
alloc_shape[:-1] + alloc_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim, ), (self.model_config.hf_text_config.qk_rope_head_dim, ),
dtype=self.dtype, dtype=self.dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
device=device) device=device)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D # when entry_shape is higher than 1D
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe)) kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
else:
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append(layer_kv_cache)
return kv_cache return kv_cache
if VLLM_ENABLE_GRAPH_MODE == '1': CacheEngine._allocate_kv_cache = allocate_kv_cache
CacheEngine._allocate_kv_cache = allocate_kv_cache

View File

@@ -32,7 +32,7 @@ import torch_npu
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
@@ -56,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists, from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
is_pin_memory_available, supports_dynamo) is_pin_memory_available)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
@@ -546,8 +546,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
} }
# Add graph_pad_size here # Add graph_pad_size here
if self.runner.vllm_config.compilation_config.level ==\ if self.runner.enable_graph_mode:
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len( graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
seq_lens) seq_lens)
else: else:
@@ -609,8 +608,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
] ]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
if self.runner.vllm_config.compilation_config.level ==\ if self.runner.enable_graph_mode:
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
torch._dynamo.mark_static(input_tokens_tensor) torch._dynamo.mark_static(input_tokens_tensor)
torch._dynamo.mark_static(input_positions_tensor) torch._dynamo.mark_static(input_positions_tensor)
torch._dynamo.mark_static(attn_metadata.block_tables) torch._dynamo.mark_static(attn_metadata.block_tables)
@@ -871,6 +869,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.max_batchsize_to_capture = \ self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size self.vllm_config.compilation_config.max_capture_size
self.enable_graph_mode = False
additional_config = vllm_config.additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
self.has_inner_state = model_config.has_inner_state self.has_inner_state = model_config.has_inner_state
self.in_profile_run = False self.in_profile_run = False
@@ -971,8 +975,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
# adapter torch compile with npu_backend # adapter torch compile with npu_backend
if self.vllm_config.compilation_config.level ==\ if self.enable_graph_mode:
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
import torchair # type: ignore import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore from torchair import patch_for_hcom # type: ignore
@@ -1279,15 +1282,12 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.attn_state.begin_forward(model_input) self.attn_state.begin_forward(model_input)
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
if self.vllm_config.compilation_config.level ==\ # TODO(zzzzwwjj): Do we need to do it every time?
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): if self.enable_graph_mode:
torch._dynamo.mark_static(model_input.input_tokens) torch._dynamo.mark_static(model_input.input_tokens)
torch._dynamo.mark_static(model_input.input_positions) torch._dynamo.mark_static(model_input.input_positions)
torch._dynamo.mark_static(model_input.attn_metadata.block_tables) torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping) torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
torch._dynamo.mark_static(
model_input.attn_metadata.query_start_loc)
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
for kv in kv_caches: for kv in kv_caches:
if isinstance(kv, tuple): if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0]) torch._dynamo.mark_static(kv[0])
@@ -1298,7 +1298,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
virtual_engine = model_input.virtual_engine virtual_engine = model_input.virtual_engine
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
previous_hidden_states = kwargs.get("previous_hidden_states") previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and self.vllm_config.compilation_config.level > 0: if prefill_meta is None and self.enable_graph_mode:
model_executable = self.compile_model model_executable = self.compile_model
# Note: graph_batch_size value not same as GPU # Note: graph_batch_size value not same as GPU
graph_batch_size = model_input.input_tokens.shape[ # type: ignore graph_batch_size = model_input.input_tokens.shape[ # type: ignore
@@ -1341,9 +1341,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {} } if self.has_inner_state else {}
if self.vllm_config.compilation_config.level ==\ if self.enable_graph_mode:
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): model_kwargs: Dict[str, Any] = {"inputs_embeds": None}
model_kwargs = {"inputs_embeds": None}
else: else:
model_kwargs = {} model_kwargs = {}
if previous_hidden_states is not None: if previous_hidden_states is not None:
@@ -1360,6 +1359,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.vllm_config, virtual_engine): self.vllm_config, virtual_engine):
if model_input.attn_metadata is not None: if model_input.attn_metadata is not None:
model_input.attn_metadata.input_positions = model_input.input_positions model_input.attn_metadata.input_positions = model_input.input_positions
if self.enable_graph_mode:
model_kwargs["kv_caches"] = kv_caches
model_kwargs["attn_metadata"] = model_input.attn_metadata
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
@@ -1430,8 +1432,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_states = hidden_or_intermediate_states.index_select( hidden_states = hidden_or_intermediate_states.index_select(
0, indices) 0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states output.prefill_hidden_states = hidden_or_intermediate_states
elif self.vllm_config.compilation_config.level == \ elif self.enable_graph_mode:
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
hidden_states = hidden_or_intermediate_states[:len(indices)] hidden_states = hidden_or_intermediate_states[:len(indices)]
else: else:
hidden_states = hidden_or_intermediate_states hidden_states = hidden_or_intermediate_states

View File

@@ -24,7 +24,7 @@ import torch
import torch.distributed import torch.distributed
from torch import nn from torch import nn
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
@@ -300,7 +300,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
from contextlib import nullcontext from contextlib import nullcontext
context = nullcontext() # type: ignore context = nullcontext() # type: ignore
with context: with context:
self._init_cache_engine() with set_current_vllm_config(self.vllm_config):
self._init_cache_engine()
self._warm_up_model() self._warm_up_model()
def _init_cache_engine(self): def _init_cache_engine(self):
@@ -511,10 +512,9 @@ class NPUWorker(LocalOrDistributedWorkerBase):
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
expert_tensor_parallel_size = 1 expert_tensor_parallel_size = 1
if additional_config is not None and hasattr( if additional_config:
additional_config, "expert_tensor_parallel_size"): expert_tensor_parallel_size = additional_config.get(
expert_tensor_parallel_size = getattr( "expert_tensor_parallel_size", 1)
additional_config, "expert_tensor_parallel_size")
init_ascend_model_parallel(parallel_config.tensor_parallel_size, init_ascend_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
expert_tensor_parallel_size) expert_tensor_parallel_size)