### What this PR does / why we need it?
The main goal of this PR is to alleviate the high maintenance burden caused by
model duplication when we optimize models. Some of our optimized models
diverge only slightly from vLLM's modeling code, yet they require rewriting
several parts of the original, which brings a non-negligible maintenance
burden to vllm-ascend. To solve that, we propose to leverage `torch.compile`
and the inductor pattern matcher to automatically fuse the patterns we want
to merge. For more details, refer to the RFC
https://github.com/vllm-project/vllm-ascend/issues/4239
This PR integrates the fusion of `AddRMSNorm` and the `Quant` operator, which
can improve the inference speed of models using `w8a8` quantization.
### Does this PR introduce _any_ user-facing change?
Yes, it adds a new `ascend_compilation_config` option to `additional_config`.
### How was this patch tested?
```python
def main():
    prompts = [
        "The president of the United States is Mr.",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
```text
Prompt: 'The president of the United States is Mr.', Generated text: ' Trump. The president of the United States is Mr. Biden. Which of the following statements is correct? \n\nA. Mr. Trump is Mr. Biden. \nB. Mr. Trump is not Mr. Biden. \nC. The president of the United States is not Mr. Trump. \nD. The president of the United States is not Mr. Biden.\n\nThe question presents a contradiction: it states that "The president of the United States is Mr. Trump" and "The president of'
```
- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main:
86e178f7c4
---------
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
372 lines
16 KiB
Python
372 lines
16 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Optional
|
|
|
|
from vllm.logger import logger
|
|
|
|
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
|
|
|
|
|
|
def _check_torchair_supported(model_type: str):
|
|
for supported_model in TORCHAIR_MODEL_LIST:
|
|
if supported_model in model_type.lower():
|
|
return True
|
|
return False
|
|
|
|
|
|
class AscendConfig:
    """
    Configuration Object for additional_config from vllm.configs.
    """

    def __init__(self, vllm_config):
        """
        Build the Ascend configuration from ``vllm_config.additional_config``.

        A missing (``None``) ``additional_config`` is treated as an empty
        dict, so every option falls back to the default used below.

        Args:
            vllm_config: The vLLM config object. This constructor reads its
                ``additional_config``, ``parallel_config``, ``model_config``
                and ``kv_transfer_config`` attributes.

        Raises:
            AssertionError: If an option is combined with a setup that the
                checks below reject (e.g. ``lmhead_tensor_parallel_size``
                with ``tensor_parallel_size != 1``).

        Errors raised by the sub-config constructors propagate unchanged.
        """
        additional_config = vllm_config.additional_config
        if additional_config is None:
            additional_config = {}

        torchair_graph_config = additional_config.get("torchair_graph_config",
                                                      {})
        self.torchair_graph_config = TorchairGraphConfig(
            torchair_graph_config, vllm_config, additional_config)

        ascend_compilation_config = additional_config.get(
            "ascend_compilation_config", {})
        self.ascend_compilation_config = AscendCompilationConfig(
            **ascend_compilation_config)

        ascend_scheduler_config = additional_config.get(
            "ascend_scheduler_config", {})
        self.ascend_scheduler_config = AscendSchedulerConfig(
            ascend_scheduler_config)

        # Dump / PrecisionDebugger configuration
        dump_config_path = additional_config.get("dump_config", None)
        self.dump_config = DumpConfig(dump_config_path)

        weight_prefetch_config = additional_config.get(
            "weight_prefetch_config", {})
        self.weight_prefetch_config = WeightPrefetchConfig(
            weight_prefetch_config)

        # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
        self.expert_map_path = additional_config.get("expert_map_path", None)
        self.eplb_policy_type = additional_config.get("eplb_policy_type", 1)
        self.expert_map_record_path = additional_config.get(
            "expert_map_record_path",
            None)  # Provide path to export expert map
        self.init_redundancy_expert = additional_config.get(
            "init_redundancy_expert", 0)
        self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
        self.num_iterations_eplb_update = additional_config.get(
            "num_iterations_eplb_update", 400)
        self.gate_eplb = additional_config.get("gate_eplb", False)
        self.num_wait_worker_iterations = additional_config.get(
            "num_wait_worker_iterations", 30)
        self.chunked_prefill_for_mla = additional_config.get(
            "chunked_prefill_for_mla", False)

        # Only kept when torchair graph mode is off and expert parallelism
        # is on; otherwise the `and` chain yields a falsy value.
        self.enable_shared_expert_dp = (
            additional_config.get("enable_shared_expert_dp", False)
            and not self.torchair_graph_config.enabled
            and vllm_config.parallel_config.enable_expert_parallel)
        if self.enable_shared_expert_dp:
            from vllm_ascend.utils import enable_sp

            # Raise explicitly rather than `assert`: under `python -O` an
            # assert statement (including the enable_sp call inside it) is
            # removed entirely.
            if not enable_sp(vllm_config=vllm_config,
                             enable_shared_expert_dp=True):
                raise AssertionError(
                    "enable_shared_expert_dp is set but enable_sp() returned a false value"
                )

        self.multistream_overlap_shared_expert = additional_config.get(
            "multistream_overlap_shared_expert", False)
        self.recompute_scheduler_enable = additional_config.get(
            "recompute_scheduler_enable", False)

        self.lmhead_tensor_parallel_size = additional_config.get(
            "lmhead_tensor_parallel_size", None)
        if self.lmhead_tensor_parallel_size is not None:
            logger.info(
                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
            )
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise AssertionError(
                    "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
                )

        self.oproj_tensor_parallel_size = additional_config.get(
            "oproj_tensor_parallel_size", None)
        if self.oproj_tensor_parallel_size is not None:
            logger.info(
                f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
            )
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in the pure DP scenario"
                )
            if vllm_config.model_config.enforce_eager is True:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in graph mode"
                )
            if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
                )

        self.enable_cpu_binding = additional_config.get(
            "enable_cpu_binding", False)

        # Prefill/decode ratios; they stay at 1 unless a kv_transfer_config
        # is present and the model is not a DeepSeek MLA model.
        self.pd_tp_ratio = 1
        self.pd_head_ratio = 1
        self.num_head_replica = 1
        if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
            prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
                "prefill", {"tp_size": 1})["tp_size"]
            decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
                "decode", {"tp_size": 1})["tp_size"]
            # Explicit raise so the validation also runs under `python -O`.
            if prefill_tp_size % decode_tp_size != 0:
                raise AssertionError(
                    "Prefill TP size must be divisible by Decode TP size.")
            self.pd_tp_ratio = prefill_tp_size // decode_tp_size
            if self.pd_tp_ratio > 1:
                try:
                    # only support Qwen model now
                    # TODO: use a more robust method to get kv_head_num
                    num_kv_head = vllm_config.model_config.hf_config.num_key_value_heads
                    self.num_head_replica = prefill_tp_size // num_kv_head if prefill_tp_size >= num_kv_head else 1
                    # TP sizes are capped at the KV head count before the
                    # head ratio is computed.
                    prefill_tp_size = min(prefill_tp_size, num_kv_head)
                    decode_tp_size = min(decode_tp_size, num_kv_head)
                    self.pd_head_ratio = prefill_tp_size // decode_tp_size
                except Exception as err:
                    # Keep the AssertionError type callers see, but chain the
                    # underlying error so the real cause is not hidden.
                    raise AssertionError(
                        "Can not get num_key_value_heads from model_config"
                    ) from err

            if self.pd_tp_ratio == 0:
                raise AssertionError(
                    "Only support P node tp size lagger then D node tp size")

        self.SLO_limits_for_dynamic_batch = additional_config.get(
            "SLO_limits_for_dynamic_batch", -1)

        from vllm_ascend.utils import \
            get_flashcomm2_oproj_tp_size_and_validate_config

        # The helper receives this partially built config together with
        # vllm_config; its return value is stored as-is.
        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
            self, vllm_config)
|
|
|
|
|
|
class AscendCompilationConfig:
    """
    Options that steer Ascend graph optimization.

    The class is the place to configure graph fusion optimizations, which
    directly affect the performance and behavior of models deployed on
    Ascend platforms.
    """

    def __init__(self, enable_quantization_fusion: bool = True, **kwargs):
        """
        Record the compilation options.

        Args:
            enable_quantization_fusion (bool): Switch for the quantization
                fusion optimization. When True, quantization-related
                operations are optimized, reducing the number of
                quantization/dequantization nodes.
                Default: True

            **kwargs: Extra options, accepted for forward compatibility and
                configuration extension. They are not stored.
        """
        # Unknown options are tolerated but deliberately dropped.
        del kwargs
        self.enable_quantization_fusion = enable_quantization_fusion
        # Add more compilation related configs here as needed
|
|
|
|
|
|
class TorchairGraphConfig:
    """
    Configuration Object for torchair_graph_config from additional_config
    """

    def __init__(self, torchair_graph_config, vllm_config, additional_config):
        """
        Read the torchair graph options and validate their combination.

        Args:
            torchair_graph_config: Mapping holding the torchair options;
                missing keys fall back to the defaults used below.
            vllm_config: Config object; only
                ``parallel_config.tensor_parallel_size`` is read here.
            additional_config: Mapping consulted for
                ``multistream_overlap_shared_expert``.

        Raises:
            TypeError: ``graph_batch_sizes`` is not a list.
            ValueError: ``graph_batch_sizes_init`` is set together with a
                non-empty ``graph_batch_sizes``.
            RuntimeError: An option is set whose prerequisite is not met.
        """
        option = torchair_graph_config.get
        self.enabled = option("enabled", False)
        self.mode = option("mode", '')
        self.use_cached_graph = option("use_cached_graph", False)
        self.use_cached_kv_cache_bytes = option("use_cached_kv_cache_bytes",
                                                False)
        self.graph_batch_sizes = option("graph_batch_sizes", [])
        self.graph_batch_sizes_init = option("graph_batch_sizes_init", False)
        self.enable_multistream_mla = option("enable_multistream_mla", False)
        self.enable_view_optimize = option("enable_view_optimize", True)
        self.enable_frozen_parameter = option("enable_frozen_parameter", True)
        self.enable_kv_nz = option("enable_kv_nz", False)
        self.enable_super_kernel = option("enable_super_kernel", False)

        if not isinstance(self.graph_batch_sizes, list):
            raise TypeError("graph_batch_sizes must be list[int]")
        if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
            raise ValueError(
                "graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
            )

        if not self.enabled:
            # Options that may only be set while torchair graph mode is on.
            # They are checked in this order and the first truthy one is
            # reported.
            graph_only_options = (
                "mode",
                "use_cached_graph",
                "use_cached_kv_cache_bytes",
                "graph_batch_sizes",
                "graph_batch_sizes_init",
                "enable_multistream_mla",
                "enable_kv_nz",
                "enable_super_kernel",
            )
            for name in graph_only_options:
                if getattr(self, name):
                    raise RuntimeError(
                        f"{name} is valid only when Torchair graph mode is enabled"
                    )

        if self.enable_super_kernel:
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise RuntimeError(
                    "enable_super_kernel is valid only when tensor_parallel_size is 1"
                )
            overlap_enabled = additional_config.get(
                "multistream_overlap_shared_expert", False)
            if not overlap_enabled:
                raise RuntimeError(
                    "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
                )

        if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
            raise RuntimeError(
                "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
            )
|
|
|
|
|
|
class AscendSchedulerConfig:
    """
    Configuration Object for ascend_scheduler_config from additional_config
    """

    def __init__(self, ascend_scheduler_config: dict):
        """
        Copy the scheduler options onto this object as attributes.

        Args:
            ascend_scheduler_config: Mapping of option name to value.
                ``enabled`` defaults to False when absent.
        """
        self.enabled = ascend_scheduler_config.get("enabled", False)
        # Ascend scheduler is based on vllm v0 scheduler, so we should support
        # all vllm v0 scheduler configs as well.
        for option_name in ascend_scheduler_config:
            # Names the object already exposes (such as `enabled`) are left
            # untouched.
            if hasattr(self, option_name):
                continue
            setattr(self, option_name, ascend_scheduler_config[option_name])
|
|
|
|
|
|
class DumpConfig:
    """
    Holds the dump/PrecisionDebugger settings.
    """

    def __init__(self, dump_config_path: Optional[str] = None):
        """
        Store the dump config path and derive the enable flag from it.

        Args:
            dump_config_path: Path to the msprobe config json, or None.
        """
        # Path to msprobe config json; may be None.
        self.config_path: Optional[str] = dump_config_path
        # Dumping is on exactly when a non-empty path was given.
        self.enable_dump: bool = True if dump_config_path else False
|
|
|
|
|
|
class WeightPrefetchConfig:
    """
    Configuration Object for weight_prefetch_config from additional_config
    """

    # Default prefetch ratios, used when the user config has no
    # "prefetch_ratio" entry. Keys are a group name, then a weight name.
    # NOTE(review): exact meaning of the ratio values is defined by the
    # consumer of this config - confirm there.
    prefetch_ratio: dict = {
        "attn": {
            "qkv": 1.0,
            "o": 1.0,
        },
        "moe": {
            "gate_up": 0.8
        }
    }

    def __init__(self, weight_prefetch_config: dict):
        """
        Read the weight prefetch options.

        Args:
            weight_prefetch_config: Mapping that may contain ``enabled``
                (default False) and ``prefetch_ratio``. A user-supplied
                ``prefetch_ratio`` is stored as given; otherwise a private
                copy of the class-level default is used.
        """
        import copy

        self.enabled = weight_prefetch_config.get("enabled", False)
        if "prefetch_ratio" in weight_prefetch_config:
            self.prefetch_ratio = weight_prefetch_config["prefetch_ratio"]
        else:
            # Copy the default instead of aliasing it: the class-level dict
            # is shared, so mutating an aliased instance attribute would
            # silently change the default for every other instance.
            self.prefetch_ratio = copy.deepcopy(type(self).prefetch_ratio)
|
|
|
|
|
|
# Module-level cache of the active AscendConfig; None until initialized.
_ASCEND_CONFIG: Optional[AscendConfig] = None


def init_ascend_config(vllm_config):
    """Create the module-level AscendConfig, or return the cached one.

    A new AscendConfig is built when none is cached yet, or when
    ``additional_config`` carries a truthy ``"refresh"`` entry. Otherwise the
    cached object is returned untouched.

    Args:
        vllm_config: Config object whose ``additional_config`` is read here
            and which is passed on to ``AscendConfig``.

    Returns:
        The cached (possibly freshly built) AscendConfig.
    """
    global _ASCEND_CONFIG
    additional_config = vllm_config.additional_config
    if additional_config is None:
        additional_config = {}
    # An empty dict has no "refresh" key, so the lookup alone covers the
    # "no additional config" case.
    refresh = additional_config.get("refresh", False)
    if _ASCEND_CONFIG is None or refresh:
        _ASCEND_CONFIG = AscendConfig(vllm_config)
    return _ASCEND_CONFIG
|
|
|
|
|
|
def clear_ascend_config():
    """Reset the module-level ``_ASCEND_CONFIG`` cache to None."""
    global _ASCEND_CONFIG
    _ASCEND_CONFIG = None
|
|
|
|
|
|
def get_ascend_config():
    """Return the cached module-level Ascend config.

    Raises:
        RuntimeError: If no config is cached, i.e. ``_ASCEND_CONFIG`` is
            still None.
    """
    if _ASCEND_CONFIG is not None:
        return _ASCEND_CONFIG
    raise RuntimeError(
        "Ascend config is not initialized. Please call init_ascend_config first."
    )
|
|
|
|
|
|
def check_ascend_config(vllm_config, enforce_eager):
    """Validate the cached Ascend config against the execution mode.

    Args:
        vllm_config: Config object; ``model_config`` (and its
            ``hf_config.model_type``) is read when present.
        enforce_eager: Truthy when eager mode is requested.

    Raises:
        RuntimeError: Eager mode is requested while torchair graph mode is
            enabled, or the Ascend config was never initialized.
        NotImplementedError: Torchair graph mode is enabled for a model type
            that matches no entry of ``TORCHAIR_MODEL_LIST``.
    """
    ascend_config = get_ascend_config()

    # for eager mode
    if enforce_eager:
        # torchair_graph cannot be enabled with eager mode.
        if ascend_config.torchair_graph_config.enabled:
            raise RuntimeError(
                "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
            )
    # for graph mode
    else:
        # torchair_graph case
        if ascend_config.torchair_graph_config.enabled:
            # torchair_graph is supported only for the model types listed
            # in TORCHAIR_MODEL_LIST.
            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if not _check_torchair_supported(model_type):
                    raise NotImplementedError(
                        "Torchair graph mode only works with following model types:"
                        f"{TORCHAIR_MODEL_LIST}.")
            if ascend_config.enable_shared_expert_dp:
                logger.warning(
                    "enable_shared_expert_dp is not supported for torchair graph mode currently, "
                    "it has been disabled automatically.")
        # aclgraph case
        else:
            if ascend_config.ascend_compilation_config.enable_quantization_fusion:
                logger.info(
                    "Quantization fusion enabled! op fusion on quantization are expected. "
                )

            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                # Compare case-insensitively, matching how
                # _check_torchair_supported treats model_type.
                if "qwen" not in model_type.lower():
                    logger.warning(
                        "ACL Graph is currently experimental. Please "
                        "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
                        " if you encourage any Error")
|