# vllm_ascend/ascend_config.py
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from uuid import uuid4
from vllm.logger import logger


def check_kv_extra_config(vllm_config):

    def _check(name: str, config: dict):
        tp_key = "tp_size"
        dp_key = "dp_size"
        if tp_key in config:
            config_tp = config[tp_key]
            vllm_tp = vllm_config.parallel_config.tensor_parallel_size
            if config_tp != vllm_tp:
                raise ValueError(
                    f"KV transfer '{name}' config has a conflicting tensor parallel size. "
                    f"Expected {vllm_tp}, but got {config_tp}.")
        if dp_key in config:
            config_dp = config[dp_key]
            vllm_dp = vllm_config.parallel_config.data_parallel_size
            if config_dp != vllm_dp:
                raise ValueError(
                    f"KV transfer '{name}' config has a conflicting data parallel size. "
                    f"Expected {vllm_dp}, but got {config_dp}.")

    if vllm_config.kv_transfer_config.is_kv_producer:
        _check(
            "prefill",
            vllm_config.kv_transfer_config.get_from_extra_config(
                "prefill", {}))
    if vllm_config.kv_transfer_config.is_kv_consumer:
        _check(
            "decode",
            vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
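

# Illustrative sketch (values are assumptions, not from this file): the
# "prefill"/"decode" entries returned by get_from_extra_config() and validated
# above. For a KV producer the "prefill" entry is checked, for a KV consumer
# the "decode" entry; any tp_size/dp_size given must match this engine's own
# parallel sizes, e.g.
#
#   {"prefill": {"tp_size": 8, "dp_size": 1},
#    "decode": {"tp_size": 1, "dp_size": 8}}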


class AscendConfig:
    """
    Configuration Object for additional_config from vllm.configs.
    """

    def __init__(self, vllm_config):
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}

        xlite_graph_config = additional_config.get("xlite_graph_config", {})
        self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
                                                   vllm_config)
        ascend_compilation_config = additional_config.get(
            "ascend_compilation_config", {})
        self.ascend_compilation_config = AscendCompilationConfig(
            **ascend_compilation_config)
        # Dump / PrecisionDebugger configuration
        dump_config_path = additional_config.get("dump_config", None)
        self.dump_config = DumpConfig(dump_config_path)
        weight_prefetch_config = additional_config.get(
            "weight_prefetch_config", {})
        self.weight_prefetch_config = WeightPrefetchConfig(
            weight_prefetch_config)
        # TODO: Remove this config once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm.
        self.expert_map_path = additional_config.get("expert_map_path", None)
        self.eplb_policy_type = additional_config.get("eplb_policy_type", 1)
        self.expert_map_record_path = additional_config.get(
            "expert_map_record_path",
            None)  # Provide path to export expert map
        self.init_redundancy_expert = additional_config.get(
            "init_redundancy_expert", 0)
        self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
        self.num_iterations_eplb_update = additional_config.get(
            "num_iterations_eplb_update", 400)
        self.gate_eplb = additional_config.get("gate_eplb", False)
        self.num_wait_worker_iterations = additional_config.get(
            "num_wait_worker_iterations", 30)
        self.chunked_prefill_for_mla = additional_config.get(
            "chunked_prefill_for_mla", False)
        self.enable_shared_expert_dp = additional_config.get(
            "enable_shared_expert_dp",
            False) and vllm_config.parallel_config.enable_expert_parallel
        if self.enable_shared_expert_dp:
            from vllm_ascend.utils import enable_sp
            assert enable_sp(vllm_config=vllm_config,
                             enable_shared_expert_dp=True)
        self.multistream_overlap_shared_expert = additional_config.get(
            "multistream_overlap_shared_expert", False)
        self.recompute_scheduler_enable = additional_config.get(
            "recompute_scheduler_enable", False)
        self.lmhead_tensor_parallel_size = additional_config.get(
            "lmhead_tensor_parallel_size", None)
        if self.lmhead_tensor_parallel_size is not None:
            logger.info(
                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
            )
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise AssertionError(
                    "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
                )
        self.oproj_tensor_parallel_size = additional_config.get(
            "oproj_tensor_parallel_size", None)
        if self.oproj_tensor_parallel_size is not None:
            logger.info(
                f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
            )
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in the pure DP scenario"
                )
            if vllm_config.model_config.enforce_eager is True:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in graph mode"
                )
            if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
                )
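
        # Illustrative usage (value is an assumption): per the checks above,
        # oproj_tensor_parallel_size requires pure DP (tensor_parallel_size=1),
        # graph mode (enforce_eager disabled) and a KV-consumer (decode) node,
        # e.g. --additional_config='{"oproj_tensor_parallel_size": 8}'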
        self.enable_cpu_binding = additional_config.get(
            "enable_cpu_binding", False)
        if vllm_config.kv_transfer_config is not None:
            check_kv_extra_config(vllm_config)
        self.pd_tp_ratio = 1
        self.pd_head_ratio = 1
        self.num_head_replica = 1
        if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
            prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
                "prefill", {"tp_size": 1})["tp_size"]
            decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
                "decode", {"tp_size": 1})["tp_size"]
            assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
            self.pd_tp_ratio = prefill_tp_size // decode_tp_size
            if self.pd_tp_ratio > 1:
                try:
                    # only support Qwen model now
                    # TODO: use a more robust method to get kv_head_num
                    num_kv_head = vllm_config.model_config.hf_config.num_key_value_heads
                    self.num_head_replica = prefill_tp_size // num_kv_head if prefill_tp_size >= num_kv_head else 1
                    prefill_tp_size = min(prefill_tp_size, num_kv_head)
                    decode_tp_size = min(decode_tp_size, num_kv_head)
                    self.pd_head_ratio = prefill_tp_size // decode_tp_size
                except Exception:
                    raise AssertionError(
                        "Cannot get num_key_value_heads from model_config")
            if self.pd_tp_ratio == 0:
                raise AssertionError(
                    "Only support P node tp size larger than D node tp size")
        self.SLO_limits_for_dynamic_batch = additional_config.get(
            "SLO_limits_for_dynamic_batch", -1)
        from vllm_ascend.utils import get_flashcomm2_config_and_validate
        self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
            self, vllm_config)
        self.enable_npugraph_ex = additional_config.get(
            "enable_npugraph_ex", False)
        if self.enable_npugraph_ex:
            raise NotImplementedError(
                "This feature is still experimental and will be supported soon."
            )
        kv_cfg = vllm_config.kv_transfer_config
        if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
                                              False):
            kv_cfg.engine_id = f"{kv_cfg.engine_id}-{uuid4().hex}"
            kv_cfg._engine_id_patched = True
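

# Illustrative sketch (values are assumptions, not from the source): a few of
# the additional_config keys read by AscendConfig above, passed via vLLM's
# --additional_config (or the corresponding engine argument).
#
#   additional_config = {
#       "ascend_compilation_config": {"fuse_norm_quant": True},
#       "xlite_graph_config": {"enabled": True, "full_mode": False},
#       "weight_prefetch_config": {"enabled": True},
#       "dump_config": "/path/to/msprobe_config.json",
#       "dynamic_eplb": True,
#       "num_iterations_eplb_update": 400,
#   }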


class AscendCompilationConfig:
    """
    Configuration for controlling the behavior of Ascend graph optimization.
    This class provides a way to configure graph fusion optimizations.
    These configurations directly impact the performance and behavior of models
    deployed on Ascend platforms.
    """

    def __init__(self, fuse_norm_quant: bool = True, **kwargs):
"""
Initialize the configuration.
Args:
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations.
                Default: True
            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
        """
        self.fuse_norm_quant = fuse_norm_quant
        # Add more compilation related configs here as needed
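        # Note: unknown keys in ascend_compilation_config are accepted through
        # **kwargs for forward compatibility and are currently ignored, e.g.
        # (illustrative, "future_flag" is a made-up name)
        #   AscendCompilationConfig(fuse_norm_quant=False, future_flag=True)
        # only records fuse_norm_quant.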


class XliteGraphConfig:
    """
    Configuration Object for xlite_graph_config from additional_config
    """

    def __init__(self, xlite_graph_config, vllm_config):
        self.enabled = xlite_graph_config.get("enabled", False)
        self.full_mode = xlite_graph_config.get("full_mode", False)
        if self.enabled:
            if bool(vllm_config.speculative_config):
                raise RuntimeError(
                    "Xlite graph mode is not compatible with speculative decoding. Please disable speculative decoding."
                )
            if vllm_config.parallel_config.pipeline_parallel_size > 1:
                raise RuntimeError(
                    "Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
                )
            if vllm_config.cache_config.block_size != 128:
                raise RuntimeError(
                    "Xlite graph mode is only compatible with block_size of 128. Please set block_size to 128."
                )
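
# Illustrative example (values are assumptions): per the checks above, enabling
# xlite graph mode requires a cache block_size of 128, pipeline_parallel_size
# of 1 and no speculative decoding, e.g.
#   --additional_config='{"xlite_graph_config": {"enabled": true, "full_mode": false}}'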


class DumpConfig:
    """
    Configuration object for dump/PrecisionDebugger settings.
    """

    def __init__(self, dump_config_path: Optional[str] = None):
        # enable_dump is True when a dump config path is provided and non-empty.
        self.enable_dump: bool = bool(dump_config_path)
        # Path to msprobe config json; may be None.
        self.config_path: Optional[str] = dump_config_path
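
# Illustrative usage (the path is an example only):
#   DumpConfig("/path/to/msprobe_config.json").enable_dump  -> True
#   DumpConfig(None).enable_dump                            -> False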


class WeightPrefetchConfig:
    """
    Configuration Object for weight_prefetch_config from additional_config
    """

    prefetch_ratio: dict = {
        "attn": {
            "qkv": 1.0,
            "o": 1.0,
        },
        "moe": {
            "gate_up": 0.8
        }
    }

    def __init__(self, weight_prefetch_config: dict):
        self.enabled = weight_prefetch_config.get("enabled", False)
        self.prefetch_ratio = weight_prefetch_config.get(
            "prefetch_ratio", self.prefetch_ratio)
_ASCEND_CONFIG: Optional[AscendConfig] = None


def init_ascend_config(vllm_config):
    additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
    refresh = additional_config.get("refresh",
                                    False) if additional_config else False
    global _ASCEND_CONFIG
    if _ASCEND_CONFIG is not None and not refresh:
        return _ASCEND_CONFIG
    _ASCEND_CONFIG = AscendConfig(vllm_config)
    return _ASCEND_CONFIG


def clear_ascend_config():
    global _ASCEND_CONFIG
    _ASCEND_CONFIG = None


def get_ascend_config():
    global _ASCEND_CONFIG
    if _ASCEND_CONFIG is None:
        raise RuntimeError(
            "Ascend config is not initialized. Please call init_ascend_config first."
        )
    return _ASCEND_CONFIG
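

# Illustrative lifecycle (assumption about call order): the config is built once
# and then read as a module-level singleton.
#
#   init_ascend_config(vllm_config)   # builds and caches AscendConfig
#   cfg = get_ascend_config()         # raises RuntimeError if uninitialized
#   clear_ascend_config()             # drops the cached instance
#
# Passing additional_config={"refresh": True} makes init_ascend_config rebuild
# the cached instance instead of returning the existing one.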