[Misc] Refactor additional_config (#1029)
More and more config options are added to additional_config. This PR provide a new AscendConfig to manage these config options by an easier way to make code cleaner and readable. This PR also added the `additional_config` doc for users. Added the test_ascend_config.py to make sure the new AscendConfig works as expect. TODO: Add e2e test with torchair and deepseek once the CI resource is available. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -20,10 +20,11 @@
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
def allocate_kv_cache(
|
||||
self,
|
||||
@@ -36,8 +37,8 @@ def allocate_kv_cache(
|
||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||
kv_cache: List[Any] = []
|
||||
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config and additional_config.get("enable_graph_mode", False):
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# Align entries so they are 256 byte aligned for better performance
|
||||
# Primarily targets MLA as this typically only ends up having entries
|
||||
# be 128 byte aligned.
|
||||
|
||||
@@ -64,6 +64,8 @@ from vllm.worker.model_runner_base import (
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
@@ -540,7 +542,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
}
|
||||
|
||||
# Add graph_pad_size here
|
||||
if self.runner.enable_graph_mode:
|
||||
if self.runner.torchair_graph_enabled:
|
||||
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
|
||||
seq_lens)
|
||||
else:
|
||||
@@ -603,7 +605,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
if self.runner.enable_graph_mode:
|
||||
if self.runner.torchair_graph_enabled:
|
||||
torch._dynamo.mark_static(input_tokens_tensor)
|
||||
torch._dynamo.mark_static(input_positions_tensor)
|
||||
torch._dynamo.mark_static(attn_metadata.block_tables)
|
||||
@@ -864,14 +866,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
self.max_batchsize_to_capture = \
|
||||
self.vllm_config.compilation_config.max_capture_size
|
||||
|
||||
self.enable_graph_mode = False
|
||||
self.use_cached_npu_graph = False
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
self.use_cached_npu_graph = additional_config.get(
|
||||
"use_cached_npu_graph", False)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
|
||||
self.has_inner_state = model_config.has_inner_state
|
||||
|
||||
@@ -971,7 +968,7 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.enable_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
@@ -1290,7 +1287,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
|
||||
assert model_input.attn_metadata is not None
|
||||
# TODO(zzzzwwjj): Do we need to do it every time?
|
||||
if self.enable_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
torch._dynamo.mark_static(model_input.input_tokens)
|
||||
torch._dynamo.mark_static(model_input.input_positions)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
|
||||
@@ -1305,7 +1302,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
virtual_engine = model_input.virtual_engine
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
if prefill_meta is None and self.enable_graph_mode:
|
||||
if prefill_meta is None and self.torchair_graph_enabled:
|
||||
model_executable = self.compile_model
|
||||
# Note: graph_batch_size value not same as GPU
|
||||
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
|
||||
@@ -1359,7 +1356,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
if self.enable_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs: Dict[str, Any] = {"inputs_embeds": None}
|
||||
else:
|
||||
model_kwargs = {}
|
||||
@@ -1377,7 +1374,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.vllm_config, virtual_engine):
|
||||
if model_input.attn_metadata is not None:
|
||||
model_input.attn_metadata.input_positions = model_input.input_positions
|
||||
if self.enable_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = kv_caches
|
||||
model_kwargs["attn_metadata"] = model_input.attn_metadata
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
@@ -1461,7 +1458,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif self.enable_graph_mode:
|
||||
elif self.torchair_graph_enabled:
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
@@ -61,6 +61,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
@@ -137,13 +138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config and additional_config.get(
|
||||
"ascend_scheduler_config", None) is not None:
|
||||
self.use_v0_scheduler = True
|
||||
else:
|
||||
self.use_v0_scheduler = False
|
||||
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.vllm_config.scheduler_config.max_num_seqs,
|
||||
(self.model_config.max_model_len + self.block_size - 1) //
|
||||
@@ -326,25 +320,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.attn_mask_len, self.dtype)
|
||||
|
||||
self.sampler = Sampler()
|
||||
self.enable_torchair_graph_mode = False
|
||||
self.use_cached_npu_graph = False
|
||||
self.torchair_graph_batch_sizes = []
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config:
|
||||
self.enable_torchair_graph_mode = additional_config.get(
|
||||
"enable_graph_mode",
|
||||
False) and self.vllm_config.model_config.use_mla
|
||||
self.use_cached_npu_graph = additional_config.get(
|
||||
"use_cached_npu_graph", False)
|
||||
self.torchair_graph_batch_sizes = additional_config.get(
|
||||
"torchair_graph_batch_sizes", [])
|
||||
if not isinstance(self.torchair_graph_batch_sizes, list):
|
||||
logger.warning("torchair_graph_batch_sizes must be list[int]")
|
||||
self.torchair_graph_batch_sizes = []
|
||||
if len(self.torchair_graph_batch_sizes
|
||||
) == 0 and additional_config.get(
|
||||
"torchair_graph_batch_sizes_init", False):
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
|
||||
self.torchair_graph_use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
|
||||
@@ -628,13 +611,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
# splitfuse
|
||||
elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
@@ -671,7 +655,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
||||
|
||||
# Add graph_pad_size here
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
|
||||
and not with_prefill):
|
||||
batch_size = len(seq_lens)
|
||||
if self.dp_size > 1:
|
||||
@@ -715,7 +699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
|
||||
if (envs_ascend.VLLM_ENABLE_MC2
|
||||
or self.enable_torchair_graph_mode) and not with_prefill:
|
||||
or self.torchair_graph_enabled) and not with_prefill:
|
||||
input_ids = self.input_ids[:padded_batch_size]
|
||||
positions = self.positions[:padded_batch_size]
|
||||
|
||||
@@ -724,10 +708,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
model_kwargs = {}
|
||||
if self.enable_torchair_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
if self.enable_torchair_graph_mode and not with_prefill:
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
hidden_states = self.compile_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -1170,7 +1154,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
with set_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
if self.enable_torchair_graph_mode and not with_prefill:
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
# Only mark static while compiling
|
||||
@@ -1262,7 +1246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.enable_torchair_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
@@ -1339,7 +1323,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
if self.enable_torchair_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
@@ -1417,7 +1401,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||
# torchair graph capture can cause some issues, so now we just
|
||||
# temporarily split the codepath for the two different graph patterns.
|
||||
if self.enable_torchair_graph_mode:
|
||||
if self.torchair_graph_enabled:
|
||||
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||
graph_num = len(torchair_graph_batch_sizes)
|
||||
logger.info(
|
||||
@@ -1449,10 +1433,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(num_tokens)
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
|
||||
"Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
|
||||
CompilationLevel.PIECEWISE)
|
||||
logger.info("Skipping NPU graph capture for eager mode.")
|
||||
return
|
||||
end_time = time.perf_counter()
|
||||
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||
|
||||
@@ -47,6 +47,7 @@ from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
@@ -75,6 +76,9 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
# Register ops when worker init.
|
||||
from vllm_ascend import ops # noqa: F401
|
||||
|
||||
# init ascend config
|
||||
init_ascend_config(vllm_config)
|
||||
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
# Try to import mindie_turbo to accelerate vLLM inference.
|
||||
try_register_lib(
|
||||
|
||||
@@ -42,6 +42,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib
|
||||
@@ -67,6 +68,8 @@ class NPUWorker(WorkerBase):
|
||||
from vllm_ascend import ops
|
||||
ops.register_dummy_fusion_op()
|
||||
_register_atb_extensions()
|
||||
# init ascend config
|
||||
init_ascend_config(vllm_config)
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
@@ -236,7 +239,7 @@ class NPUWorker(WorkerBase):
|
||||
if runner.dp_size > 1:
|
||||
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
|
||||
1, False)
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled:
|
||||
if not with_prefill:
|
||||
num_tokens = max_num_tokens
|
||||
num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
|
||||
|
||||
Reference in New Issue
Block a user