[Misc] Refactor additional_config (#1029)

More and more config options are added to additional_config. This PR
provide a new AscendConfig to manage these config options by an easier
way to make code cleaner and readable.

 This PR also added the `additional_config` doc for users.

Added the test_ascend_config.py to make sure the new AscendConfig works
as expect.

TODO: Add e2e test with torchair and deepseek once the CI resource is
available.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-06-05 16:28:01 +08:00
committed by GitHub
parent 7737aaa40f
commit e1ab6d318e
23 changed files with 456 additions and 208 deletions

View File

@@ -61,6 +61,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -137,13 +138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
additional_config = vllm_config.additional_config
if additional_config and additional_config.get(
"ascend_scheduler_config", None) is not None:
self.use_v0_scheduler = True
else:
self.use_v0_scheduler = False
self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(self.model_config.max_model_len + self.block_size - 1) //
@@ -326,25 +320,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask_len, self.dtype)
self.sampler = Sampler()
self.enable_torchair_graph_mode = False
self.use_cached_npu_graph = False
self.torchair_graph_batch_sizes = []
additional_config = vllm_config.additional_config
if additional_config:
self.enable_torchair_graph_mode = additional_config.get(
"enable_graph_mode",
False) and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = additional_config.get(
"use_cached_npu_graph", False)
self.torchair_graph_batch_sizes = additional_config.get(
"torchair_graph_batch_sizes", [])
if not isinstance(self.torchair_graph_batch_sizes, list):
logger.warning("torchair_graph_batch_sizes must be list[int]")
self.torchair_graph_batch_sizes = []
if len(self.torchair_graph_batch_sizes
) == 0 and additional_config.get(
"torchair_graph_batch_sizes_init", False):
self.init_torchair_graph_batch_sizes()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.torchair_graph_use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
if len(self.torchair_graph_batch_sizes) == 0:
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
@@ -628,13 +611,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
ascend_config = get_ascend_config()
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
# splitfuse
elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
@@ -671,7 +655,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
and not with_prefill):
batch_size = len(seq_lens)
if self.dp_size > 1:
@@ -715,7 +699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
input_ids = self.input_ids[:num_input_tokens]
if (envs_ascend.VLLM_ENABLE_MC2
or self.enable_torchair_graph_mode) and not with_prefill:
or self.torchair_graph_enabled) and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
@@ -724,10 +708,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens):
model_kwargs = {}
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and not with_prefill:
if self.torchair_graph_enabled and not with_prefill:
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
@@ -1170,7 +1154,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
if self.enable_torchair_graph_mode and not with_prefill:
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
@@ -1262,7 +1246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
m.consumed_memory / float(2**30))
# adapter torch compile with npu_backend
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
@@ -1339,7 +1323,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
@@ -1417,7 +1401,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
if self.enable_torchair_graph_mode:
if self.torchair_graph_enabled:
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
logger.info(
@@ -1449,10 +1433,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
else:
logger.warning(
"Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
"Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
CompilationLevel.PIECEWISE)
logger.info("Skipping NPU graph capture for eager mode.")
return
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]