[Misc] Refactor additional_config (#1029)

More and more config options are being added to additional_config. This PR provides a new AscendConfig class to manage these options in a cleaner, more readable way.

This PR also adds the `additional_config` documentation for users.

Added test_ascend_config.py to make sure the new AscendConfig works as expected.

TODO: Add e2e tests with torchair and deepseek once CI resources are available.
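
For illustration, here is a minimal sketch (not part of this commit) of how the nested options could be passed from user code. The additional_config keys are the ones parsed by the new AscendConfig; the exact `LLM()` invocation is an assumption and may differ across vLLM versions.

```python
# Hypothetical usage sketch; only the additional_config keys below come from this PR.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # deepseek shown because torchair graph mode targets it
    additional_config={
        "torchair_graph_config": {
            "enabled": True,            # replaces the old flat enable_graph_mode flag
            "use_cached_graph": False,  # replaces use_cached_npu_graph
            "graph_batch_sizes": [],    # must be list[int]
            "graph_batch_sizes_init": False,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
        "expert_tensor_parallel_size": 1,
    },
)
```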

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Committed: 2025-06-05 16:28:01 +08:00 (via GitHub)
Parent: 7737aaa40f
Commit: e1ab6d318e
23 changed files with 456 additions and 208 deletions

View File

@@ -0,0 +1,138 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import vllm.envs as envs
from vllm.logger import logger
class AscendConfig:
"""
Configuration Object for additional_config from vllm.configs.
"""
def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
torchair_graph_config = additional_config.get("torchair_graph_config",
{})
self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
ascend_scheduler_config = additional_config.get(
"ascend_scheduler_config", {})
self.ascend_scheduler_config = AscendSchedulerConfig(
ascend_scheduler_config)
self.expert_tensor_parallel_size = int(
additional_config.get("expert_tensor_parallel_size", 1))
class TorchairGraphConfig:
"""
Configuration Object for torchair_graph_config from additional_config
"""
def __init__(self, torchair_graph_config):
self.enabled = torchair_graph_config.get("enabled", False)
self.use_cached_graph = torchair_graph_config.get(
"use_cached_graph", False)
self.graph_batch_sizes = torchair_graph_config.get(
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
raise ValueError(
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
)
class AscendSchedulerConfig:
"""
Configuration Object for ascend_scheduler_config from additional_config
"""
def __init__(self, ascend_scheduler_config: dict):
self.enabled = ascend_scheduler_config.get("enabled", False)
# Ascend scheduler is based on vllm v0 scheduler, so we should support
# all vllm v0 scheduler configs as well.
for k, v in ascend_scheduler_config.items():
if not hasattr(self, k):
setattr(self, k, v)
_ASCEND_CONFIG: Optional[AscendConfig] = None
def init_ascend_config(vllm_config):
global _ASCEND_CONFIG
if _ASCEND_CONFIG is not None:
return _ASCEND_CONFIG
_ASCEND_CONFIG = AscendConfig(vllm_config)
return _ASCEND_CONFIG
def clear_ascend_config():
global _ASCEND_CONFIG
_ASCEND_CONFIG = None
def get_ascend_config():
global _ASCEND_CONFIG
if _ASCEND_CONFIG is None:
raise RuntimeError(
"Ascend config is not initialized. Please call init_ascend_config first."
)
return _ASCEND_CONFIG
def check_ascend_config(vllm_config, enforce_eager):
ascend_config = get_ascend_config()
# Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
if ascend_config.torchair_graph_config.enabled and enforce_eager:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
# torchair_graph only work with deepseek model and mla enabled.
if ascend_config.torchair_graph_config.enabled:
if envs.VLLM_MLA_DISABLE:
logger.warning(
"Torchair graph mode is still experimental and not supported for V1 without mla currently, "
"it has been disabled automatically.")
ascend_config.torchair_graph_config.enabled = False
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "deepseek" not in model_type:
raise NotImplementedError(
"Torchair graph mode only works with deepseek model.")
# for V1 Engine, aclgraph doesn't work with deepseek model and only qwen model is well tested.
if envs.VLLM_USE_V1 and vllm_config.model_config is not None and not enforce_eager:
model_type = vllm_config.model_config.hf_config.model_type
if "deepseek" in model_type:
raise NotImplementedError(
"ACL Graph does not support deepseek. Please "
"try torchair graph mode to serve deepseek models on vllm-ascend."
" Or set `enforce_eager=True` to use eager mode.")
if "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
" if you encourage any Error")

View File

@@ -32,9 +32,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.config import get_current_vllm_config
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.cache import concat_and_cache_mla
from vllm_ascend.platform import CUSTOM_OP_ENABLED
from vllm_ascend.worker.model_runner import (
@@ -1002,11 +1002,8 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.w_kc = None
self.w_vc = None
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def exec_kv(
self,
@@ -1179,7 +1176,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.num_heads, -1)
# TODO: Replace the env with more flexible expressions
if self.enable_graph_mode:
if self.torchair_graph_enabled:
if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.num_prefills > 0:
slots = attn_metadata.slot_mapping
@@ -1230,7 +1227,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
if self.enable_graph_mode:
if self.torchair_graph_enabled:
# shape of query for npu graph mode should be:
# [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)

View File

@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -443,20 +443,8 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
# self.flash_attn_varlen_func = flash_attn_varlen_func
# if self.vllm_flash_attn_version is not None:
# self.flash_attn_varlen_func = \
# functools.partial(flash_attn_varlen_func,
# fa_version=self.vllm_flash_attn_version)
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
@@ -713,7 +701,7 @@ class AscendMLAImpl(MLAAttentionImpl):
if attn_metadata is None:
# Profiling run.
return output
self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -776,7 +764,7 @@ class AscendMLAImpl(MLAAttentionImpl):
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.enable_graph_mode:
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-1)
@@ -801,7 +789,7 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
if self.enable_graph_mode:
if self.torchair_graph_enabled:
if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
slots = attn_metadata.slot_mapping

View File

@@ -33,7 +33,7 @@ class AscendSchedulerConfig(SchedulerConfig):
def initialize_from_config(
cls,
vllm_scheduler_config: SchedulerConfig,
ascend_scheduler_config: dict,
ascend_scheduler_config,
):
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
@@ -45,9 +45,10 @@ class AscendSchedulerConfig(SchedulerConfig):
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
# Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
for k, v in ascend_scheduler_config.items():
scheduler_config[k] = v
# Override params in original SchedulerConfig with params in ascend_scheduler_config
for k, _ in scheduler_config.items():
if hasattr(ascend_scheduler_config, k):
scheduler_config[k] = getattr(ascend_scheduler_config, k)
return cls(**scheduler_config)
def __post_init__(self) -> None:

View File

@@ -34,8 +34,7 @@ import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
@@ -67,6 +66,7 @@ from vllm.model_executor.models.utils import (
from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor
@@ -214,11 +214,8 @@ class CustomDeepseekV2MoE(nn.Module):
self.params_dtype = torch.get_default_dtype()
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
@@ -248,7 +245,7 @@ class CustomDeepseekV2MoE(nn.Module):
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
hidden_states = chunks[self.tp_rank]
elif not self.enable_graph_mode:
elif not self.torchair_graph_enabled:
num_padding_tokens = (self.tp_size -
num_tokens % self.tp_size) % self.tp_size
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
@@ -272,7 +269,7 @@ class CustomDeepseekV2MoE(nn.Module):
) * self.routed_scaling_factor
if self.tp_size > 1:
if self.enable_graph_mode:
if self.torchair_graph_enabled:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
final_hidden_states = torch.zeros(
[num_tokens, hidden_size],
@@ -423,11 +420,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
@@ -440,7 +435,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
if self.enable_graph_mode:
if self.torchair_graph_enabled:
forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = hidden_states.shape

View File

@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -587,11 +588,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = ep_group.device_group
@@ -678,7 +676,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
elif self.enable_graph_mode or get_ep_group().world_size == 1:
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -772,11 +770,8 @@ class AscendFusedMoE(FusedMoE):
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -818,12 +813,6 @@ class AscendFusedMoE(FusedMoE):
self.ep_group = get_ep_group()
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -844,13 +833,13 @@ class AscendFusedMoE(FusedMoE):
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.enable_graph_mode:
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
elif self.enable_graph_mode and not is_prefill:
elif self.torchair_graph_enabled and not is_prefill:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
else:
@@ -878,14 +867,14 @@ class AscendFusedMoE(FusedMoE):
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.enable_graph_mode:
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
elif self.enable_graph_mode and not is_prefill:
elif self.torchair_graph_enabled and not is_prefill:
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",

View File

@@ -24,6 +24,7 @@ import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
CUSTOM_OP_ENABLED = False
@@ -117,10 +118,12 @@ class NPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# initialize ascend config from vllm additional_config
ascend_config = init_ascend_config(vllm_config)
from vllm.config import CompilationLevel # noqa: E402
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
additional_config = vllm_config.additional_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
@@ -130,11 +133,8 @@ class NPUPlatform(Platform):
# NOTE: When enable_expert_parallel is True, we follow vLLM convention:
# ep_size = world_size, which means expert_tensor_parallel_size must be 1
if (additional_config
and "expert_tensor_parallel_size" in additional_config
and not parallel_config.enable_expert_parallel):
parallel_config.expert_tensor_parallel_size = int(
additional_config["expert_tensor_parallel_size"])
if ascend_config.expert_tensor_parallel_size > 1 and not parallel_config.enable_expert_parallel:
parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size
# Calculate expert parallel size based on world size
parallel_config.expert_parallel_size = (
@@ -148,41 +148,7 @@ class NPUPlatform(Platform):
else:
enforce_eager = getattr(model_config, "enforce_eager", False)
if additional_config is not None:
enable_graph_mode = additional_config.get("enable_graph_mode",
False)
if enable_graph_mode:
if enforce_eager:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
elif envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
logger.warning(
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
"it has been disabled automatically.")
additional_config["enable_graph_mode"] = False
if model_config:
model_type = model_config.hf_config.model_type
if "deepseek" not in model_type:
raise NotImplementedError(
"enable_graph_mode only works with deepseek model."
)
# Set compilation level to NO_COMPILATION to disable ACL Graph
compilation_config.level = CompilationLevel.NO_COMPILATION
elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
model_type = model_config.hf_config.model_type
if "deepseek" in model_type:
raise NotImplementedError(
"ACL Graph does not support deepseek. Please "
"adopt additional_config={'enable_graph_mode': True} "
"to serve deepseek models with NPU graph mode on vllm-ascend with V1 engine."
" Or set `enforce_eager=True` to use eager mode.")
elif "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
" if you encourage any Error")
check_ascend_config(vllm_config, enforce_eager)
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
logger.info("Compilation disabled, using eager mode by default")
@@ -192,6 +158,11 @@ class NPUPlatform(Platform):
"NPU does not support %s compilation level. Setting level to NO_COMPILATION",
compilation_config.level)
compilation_config.level = CompilationLevel.NO_COMPILATION
elif ascend_config.torchair_graph_config.enabled:
logger.info(
"Torchair compilation enabled on NPU. Setting level to NO_COMPILATION"
)
compilation_config.level = CompilationLevel.NO_COMPILATION
else:
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
@@ -224,17 +195,15 @@ class NPUPlatform(Platform):
if envs.VLLM_USE_V1:
# Activate custom ops for v1.
compilation_config.custom_ops = ["all"]
# If ascend_scheduler_config exists in additional_config,
# extents original scheduler_config to use AscendScheduler.
if additional_config and additional_config.get(
"ascend_scheduler_config", None) is not None:
additional_scheduler_config = additional_config.get(
"ascend_scheduler_config")
# If ascend_scheduler_config is enabled,
# extents original scheduler_config to use AscendScheduler.
if ascend_config.ascend_scheduler_config.enabled:
from vllm_ascend.core.schedule_config import \
AscendSchedulerConfig
ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
vllm_config.scheduler_config, additional_scheduler_config)
vllm_config.scheduler_config,
ascend_config.ascend_scheduler_config)
vllm_config.scheduler_config = ascend_scheduler_config
@classmethod

View File

@@ -20,10 +20,10 @@ from typing import Any, Callable, Dict, Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import GroupCoordinator
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import dispose_tensor
@@ -509,11 +509,8 @@ class AscendW8A8DynamicFusedMoEMethod:
self.ep_group = get_ep_group()
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = self.ep_group.device_group
@@ -638,7 +635,7 @@ class AscendW8A8DynamicFusedMoEMethod:
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
elif self.enable_graph_mode or self.ep_group.world_size == 1:
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,

View File

@@ -20,10 +20,11 @@
from typing import Any, List
import torch
from vllm.config import get_current_vllm_config
from vllm.utils import is_pin_memory_available
from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.ascend_config import get_ascend_config
def allocate_kv_cache(
self,
@@ -36,8 +37,8 @@ def allocate_kv_cache(
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[Any] = []
additional_config = get_current_vllm_config().additional_config
if additional_config and additional_config.get("enable_graph_mode", False):
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
# Align entries so they are 256 byte aligned for better performance
# Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.

View File

@@ -64,6 +64,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -540,7 +542,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
}
# Add graph_pad_size here
if self.runner.enable_graph_mode:
if self.runner.torchair_graph_enabled:
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
seq_lens)
else:
@@ -603,7 +605,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
if self.runner.enable_graph_mode:
if self.runner.torchair_graph_enabled:
torch._dynamo.mark_static(input_tokens_tensor)
torch._dynamo.mark_static(input_positions_tensor)
torch._dynamo.mark_static(attn_metadata.block_tables)
@@ -864,14 +866,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size
self.enable_graph_mode = False
self.use_cached_npu_graph = False
additional_config = vllm_config.additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
self.use_cached_npu_graph = additional_config.get(
"use_cached_npu_graph", False)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.has_inner_state = model_config.has_inner_state
@@ -971,7 +968,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.model = self.lora_manager.create_lora_manager(self.model)
# adapter torch compile with npu_backend
if self.enable_graph_mode:
if self.torchair_graph_enabled:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
@@ -1290,7 +1287,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
assert model_input.attn_metadata is not None
# TODO(zzzzwwjj): Do we need to do it every time?
if self.enable_graph_mode:
if self.torchair_graph_enabled:
torch._dynamo.mark_static(model_input.input_tokens)
torch._dynamo.mark_static(model_input.input_positions)
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
@@ -1305,7 +1302,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
virtual_engine = model_input.virtual_engine
prefill_meta = model_input.attn_metadata.prefill_metadata
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and self.enable_graph_mode:
if prefill_meta is None and self.torchair_graph_enabled:
model_executable = self.compile_model
# Note: graph_batch_size value not same as GPU
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
@@ -1359,7 +1356,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if self.enable_graph_mode:
if self.torchair_graph_enabled:
model_kwargs: Dict[str, Any] = {"inputs_embeds": None}
else:
model_kwargs = {}
@@ -1377,7 +1374,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.vllm_config, virtual_engine):
if model_input.attn_metadata is not None:
model_input.attn_metadata.input_positions = model_input.input_positions
if self.enable_graph_mode:
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = kv_caches
model_kwargs["attn_metadata"] = model_input.attn_metadata
hidden_or_intermediate_states = model_executable(
@@ -1461,7 +1458,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif self.enable_graph_mode:
elif self.torchair_graph_enabled:
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states

View File

@@ -61,6 +61,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -137,13 +138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
additional_config = vllm_config.additional_config
if additional_config and additional_config.get(
"ascend_scheduler_config", None) is not None:
self.use_v0_scheduler = True
else:
self.use_v0_scheduler = False
self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(self.model_config.max_model_len + self.block_size - 1) //
@@ -326,25 +320,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask_len, self.dtype)
self.sampler = Sampler()
self.enable_torchair_graph_mode = False
self.use_cached_npu_graph = False
self.torchair_graph_batch_sizes = []
additional_config = vllm_config.additional_config
if additional_config:
self.enable_torchair_graph_mode = additional_config.get(
"enable_graph_mode",
False) and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = additional_config.get(
"use_cached_npu_graph", False)
self.torchair_graph_batch_sizes = additional_config.get(
"torchair_graph_batch_sizes", [])
if not isinstance(self.torchair_graph_batch_sizes, list):
logger.warning("torchair_graph_batch_sizes must be list[int]")
self.torchair_graph_batch_sizes = []
if len(self.torchair_graph_batch_sizes
) == 0 and additional_config.get(
"torchair_graph_batch_sizes_init", False):
self.init_torchair_graph_batch_sizes()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.torchair_graph_use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
if len(self.torchair_graph_batch_sizes) == 0:
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
@@ -628,13 +611,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
ascend_config = get_ascend_config()
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
# splitfuse
elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
@@ -671,7 +655,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
and not with_prefill):
batch_size = len(seq_lens)
if self.dp_size > 1:
@@ -715,7 +699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
input_ids = self.input_ids[:num_input_tokens]
if (envs_ascend.VLLM_ENABLE_MC2
or self.enable_torchair_graph_mode) and not with_prefill:
or self.torchair_graph_enabled) and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
@@ -724,10 +708,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens):
model_kwargs = {}
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and not with_prefill:
if self.torchair_graph_enabled and not with_prefill:
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
@@ -1170,7 +1154,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
if self.enable_torchair_graph_mode and not with_prefill:
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
@@ -1262,7 +1246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
m.consumed_memory / float(2**30))
# adapter torch compile with npu_backend
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
@@ -1339,7 +1323,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
@@ -1417,7 +1401,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
logger.info(
@@ -1449,10 +1433,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
else:
logger.warning(
"Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
"Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
CompilationLevel.PIECEWISE)
logger.info("Skipping NPU graph capture for eager mode.")
return
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]

View File

@@ -47,6 +47,7 @@ from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
@@ -75,6 +76,9 @@ class NPUWorker(LocalOrDistributedWorkerBase):
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
# init ascend config
init_ascend_config(vllm_config)
WorkerBase.__init__(self, vllm_config=vllm_config)
# Try to import mindie_turbo to accelerate vLLM inference.
try_register_lib(

View File

@@ -42,6 +42,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
@@ -67,6 +68,8 @@ class NPUWorker(WorkerBase):
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()
# init ascend config
init_ascend_config(vllm_config)
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
@@ -236,7 +239,7 @@ class NPUWorker(WorkerBase):
if runner.dp_size > 1:
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
1, False)
if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode:
if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled:
if not with_prefill:
num_tokens = max_num_tokens
num_tokens = runner.select_torchair_padded_batch_size(num_tokens)