init v0.11.0rc0
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
@@ -31,7 +32,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
|
||||
update_aclgraph_sizes)
|
||||
update_aclgraph_sizes, vllm_version_is)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
@@ -128,11 +129,43 @@ class NPUPlatform(Platform):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
ascend_scheduler_config = ascend_config.ascend_scheduler_config
|
||||
if vllm_version_is("0.10.2"):
|
||||
structured_outputs_config = vllm_config.decoding_config
|
||||
else:
|
||||
structured_outputs_config = vllm_config.structured_outputs_config
|
||||
|
||||
if (model_config is not None and not model_config.use_mla
|
||||
and not scheduler_config.async_scheduling):
|
||||
logger.info(
|
||||
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
|
||||
"as the performance of operators supporting this feature "
|
||||
"functionality is currently suboptimal.")
|
||||
if not model_config.is_multimodal_model and \
|
||||
structured_outputs_config.backend == "auto" and \
|
||||
not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
|
||||
not scheduler_config.send_delta_data and \
|
||||
scheduler_config.policy == "fcfs":
|
||||
ascend_scheduler_config.enabled = True
|
||||
chunked_prefill_enabled_in_ascend_scheduler = getattr(
|
||||
ascend_scheduler_config, "enable_chunked_prefill", False)
|
||||
if chunked_prefill_enabled_in_ascend_scheduler:
|
||||
logger.warning(
|
||||
"Chunked prefill feature is enabled in ascend_scheduler,"
|
||||
"but note that the operator supporting this feature "
|
||||
"would lead to performance degradation.")
|
||||
# In this situation, max_num_batched_tokens would have been rewritten.
|
||||
# So we must make sure max_num_batched_tokens is not smaller than max_model_len.
|
||||
if (scheduler_config.max_num_batched_tokens
|
||||
< scheduler_config.max_model_len
|
||||
and not chunked_prefill_enabled_in_ascend_scheduler):
|
||||
scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
|
||||
|
||||
kv_cache_dtype = vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None)
|
||||
if kv_cache_dtype is not None:
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
if model_config is None:
|
||||
logger.warning("Model config is missing. This may indicate "
|
||||
"that we are running a test case")
|
||||
@@ -148,23 +181,13 @@ class NPUPlatform(Platform):
|
||||
|
||||
compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
|
||||
# if cudagraph_mode is not explicitly set by users, set default value
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
elif compilation_config.level not in [
|
||||
if compilation_config.level not in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
|
||||
]:
|
||||
logger.warning(
|
||||
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
|
||||
compilation_config.level)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
logger.warning(
|
||||
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
@@ -185,18 +208,22 @@ class NPUPlatform(Platform):
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
if parallel_config.distributed_executor_backend == "ray":
|
||||
logger.warning(
|
||||
"Ray distributed executor backend is not compatible with ACL Graph mode "
|
||||
"right now. Setting CUDAGraphMode to NONE")
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set cudaprah sizes before extending `compilation_config.splitting_ops`
|
||||
vllm_config._set_cudagraph_sizes()
|
||||
|
||||
# TODO: Full graph is fully supported later, and the default value will be set to full graph.
|
||||
if not vllm_version_is("0.10.2"):
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
|
||||
# after MLA being supported
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
|
||||
compilation_config.cudagraph_mode
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
|
||||
and model_config.use_mla):
|
||||
logger.info(
|
||||
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
@@ -204,9 +231,28 @@ class NPUPlatform(Platform):
|
||||
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
compilation_config.use_inductor = False
|
||||
compilation_config.splitting_ops.extend(
|
||||
["vllm.unified_ascend_attention_with_output"])
|
||||
compilation_config.splitting_ops.extend([
|
||||
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
|
||||
])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
logger.info(
|
||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||
"using only ACL Graph mode")
|
||||
compilation_config.use_inductor = False
|
||||
warning_message = """\033[91m
|
||||
**********************************************************************************
|
||||
* WARNING: You have enabled the *full graph* feature.
|
||||
* This is an early experimental stage and may involve various unknown issues.
|
||||
* A known problem is that capturing too many batch sizes can lead to OOM
|
||||
* (Out of Memory) errors or inference hangs. If you encounter such issues,
|
||||
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
|
||||
* batch size for graph capture.
|
||||
* For more details, please refer to:
|
||||
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
|
||||
**********************************************************************************\033[0m
|
||||
"""
|
||||
logger.warning(warning_message)
|
||||
else:
|
||||
logger.info(
|
||||
"%s cudagraph_mode is not support on NPU. falling back to NONE",
|
||||
@@ -215,7 +261,9 @@ class NPUPlatform(Platform):
|
||||
compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
|
||||
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
|
||||
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
||||
@@ -223,6 +271,7 @@ class NPUPlatform(Platform):
|
||||
if cache_config:
|
||||
if cache_config.block_size is None:
|
||||
cache_config.block_size = 128
|
||||
|
||||
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
||||
logger.warning(
|
||||
"If prefix caching is enabled, block size must be set to 128."
|
||||
@@ -242,12 +291,6 @@ class NPUPlatform(Platform):
|
||||
ascend_config.ascend_scheduler_config)
|
||||
vllm_config.scheduler_config = ascend_scheduler_config
|
||||
|
||||
if compilation_config.pass_config.enable_sequence_parallelism:
|
||||
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
|
||||
raise NotImplementedError(
|
||||
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls,
|
||||
selected_backend,
|
||||
@@ -257,27 +300,40 @@ class NPUPlatform(Platform):
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
use_sfa,
|
||||
has_sink=False):
|
||||
if not use_v1:
|
||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||
|
||||
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if use_mla and ascend_config.enable_shared_expert_dp:
|
||||
if use_mla and not use_sfa:
|
||||
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
|
||||
if use_mla and use_sfa:
|
||||
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
|
||||
|
||||
use_torchair = ascend_config.torchair_graph_config.enabled
|
||||
# choose attention backend based on use_mla and use_torchair
|
||||
backend_map = {
|
||||
(True, True):
|
||||
(True, False, True):
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
||||
(True, False):
|
||||
(True, False, False):
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, True):
|
||||
(False, False, True):
|
||||
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
||||
(False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||
(False, False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
|
||||
(True, True, False):
|
||||
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||
(True, True, True):
|
||||
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
|
||||
}
|
||||
return backend_map[(use_mla, use_torchair)]
|
||||
return backend_map[(use_mla, use_sfa, use_torchair)]
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
|
||||
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
@@ -343,3 +399,11 @@ class NPUPlatform(Platform):
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user