From d9249c968e531d58e5befd8ca7d07d8b91cdc096 Mon Sep 17 00:00:00 2001 From: zouyida2052 Date: Wed, 29 Oct 2025 23:52:20 +0800 Subject: [PATCH] bugfix for mtp in fullgraph (#3878) ### What this PR does / why we need it? bugfix for mtp in fullgraph ### Does this PR introduce _any_ user-facing change? no --------- Signed-off-by: zouyida2052 --- docs/source/community/versioning_policy.md | 2 +- .../ModelRunner_prepare_inputs.md | 26 ++++----- .../support_matrix/supported_models.md | 2 +- vllm_ascend/platform.py | 1 + vllm_ascend/utils.py | 57 +++++++++++++------ vllm_ascend/worker/model_runner_v1.py | 8 +-- 6 files changed, 58 insertions(+), 38 deletions(-) diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md index 9cae449..1b09288 100644 --- a/docs/source/community/versioning_policy.md +++ b/docs/source/community/versioning_policy.md @@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev. Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible. ### Maintenance branch and EOL -The table below lists branch states. +The table below lists branch states. | Branch | Time Frame | Summary | | ----------------- | -------------------------------- | --------------------------------------------------------- | diff --git a/docs/source/developer_guide/feature_guide/ModelRunner_prepare_inputs.md b/docs/source/developer_guide/feature_guide/ModelRunner_prepare_inputs.md index bb2d877..04f8339 100644 --- a/docs/source/developer_guide/feature_guide/ModelRunner_prepare_inputs.md +++ b/docs/source/developer_guide/feature_guide/ModelRunner_prepare_inputs.md @@ -92,7 +92,7 @@ As the maximum number of tokens that can be schedules is 10, the scheduled token ##### 1. Get token positions: First, determine which request each token belongs to: tokens 0–2 are assigned to **request_0**, tokens 3–4 to **request_1**, and tokens 5–9 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`. -For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`). +For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`). Note: there is more efficient way (using `request indices`) to create positions in actual code. @@ -152,33 +152,33 @@ The KV cache block in the device memory is like: Let's say `K = max model len / block size = 6`, and we can get token `device block number`. The workflow of achieving slot mapping: -1. Get `block table indices` using `K`, `positions` and `request indices`. +1. Get `block table indices` using `K`, `positions` and `request indices`. Purpose: For each token, it could be used to select `device block number` from `block table`. -2. Get `device block number` using `block table indices`. +2. Get `device block number` using `block table indices`. Purpose: `device block number` indicates which device block each token belongs to. -3. Get `block offsets` using `positions` and `block size`. +3. Get `block offsets` using `positions` and `block size`. Purpose: `block offsets` indicates the offsets of each token within a block. -4. construct `slot mapping` using `device block number` and `block offsets`. +4. construct `slot mapping` using `device block number` and `block offsets`. Purpose: we can use `slot mapping` to store Token IDs into token slots. Details: -1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`. +1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`. 2. (**Token level**) Use `block table indices` to select out `device block number` for each scheduled token. The Pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number=[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]` -3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`. +3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`. 4. At last, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block_offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]` -(**Request level**) As we know the scheduled token count is `[3, 2, 5]`: +(**Request level**) As we know the scheduled token count is `[3, 2, 5]`: -- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`. -- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`. -- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`. +- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`. +- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`. +- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`. - `number of requests`: `3` - (**Request level**) `number of tokens`: `[3, 2, 5]` - `max query len`: `5` @@ -235,7 +235,7 @@ KV cache block in the device memory: 1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]` 2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]` 3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]` -4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]` +4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]` Scheduled token count:`[1, 1, 3]` - `query start location`: `[0, 1, 2, 5]` @@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]` - `slot mapping`: `[5, 14, 13, 16, 17]` -- `attention mask`: `5 * 8` +- `attention mask`: `5 * 8` Each token has a `1 * 8` vector, and there are 5 scheduled tokens. diff --git a/docs/source/user_guide/support_matrix/supported_models.md b/docs/source/user_guide/support_matrix/supported_models.md index b258104..256f033 100644 --- a/docs/source/user_guide/support_matrix/supported_models.md +++ b/docs/source/user_guide/support_matrix/supported_models.md @@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160 | GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) | | InternVL2.0/2.5/3.0
InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) | | Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) | -| Ultravox | 🟡 | Need test | \ No newline at end of file +| Ultravox | 🟡 | Need test | diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 449c3b0..dc6fb5c 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -263,6 +263,7 @@ class NPUPlatform(Platform): **********************************************************************************\033[0m """ logger.warning(warning_message) + update_aclgraph_sizes(vllm_config) else: logger.info( "%s cudagraph_mode is not support on NPU. falling back to NONE", diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d36acbe..34b98af 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -314,6 +314,13 @@ def get_max_hidden_layers(hf_config) -> int: def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: """Update ACL graph capture sizes based on hardware limitations""" + from vllm.config.compilation import CUDAGraphMode + if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + if vllm_config.speculative_config is not None and \ + vllm_config.speculative_config.num_speculative_tokens > 1: + _update_spec_aclgraph_sizes(vllm_config) + return + # NOTE: Currently, we can only capture 1800 graphs at most, # due to the limitation of ACL graph. This number is bounded by # the number of streams, which is 2048, we save 248 streams @@ -421,25 +428,43 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: vllm_config.model_config.architectures[0], num_hidden_layers, len(original_sizes)) + if vllm_config.speculative_config is not None and \ + vllm_config.speculative_config.num_speculative_tokens > 1: + _update_spec_aclgraph_sizes(vllm_config) + + +def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None: # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario # the maximum size cudagraph_capture_sizes[0] should be greater or equal than # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode - if vllm_config.speculative_config is not None and \ - vllm_config.speculative_config.num_speculative_tokens > 1: - num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens - max_num_seqs = vllm_config.scheduler_config.max_num_seqs - original_sizes, compilation_config.cudagraph_capture_sizes = \ - compilation_config.cudagraph_capture_sizes, None - assert len(original_sizes) > 0 - if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs: - enlarged_sizes = [(num_speculative_tokens + 1) * size - for size in original_sizes] - compilation_config.init_with_cudagraph_sizes(enlarged_sizes) - logger.info( - "Adjusted ACL graphs: %s → %s for speculative decoding", - original_sizes, enlarged_sizes) - else: - compilation_config.cudagraph_capture_sizes = original_sizes + from vllm.config.compilation import CUDAGraphMode + compilation_config = vllm_config.compilation_config + num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + uniform_decode_query_len = num_speculative_tokens + 1 + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + max_num_tokens = max_num_seqs * uniform_decode_query_len + original_sizes, compilation_config.cudagraph_capture_sizes = \ + compilation_config.cudagraph_capture_sizes, None + assert len(original_sizes) > 0 + + if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \ + not all(size % uniform_decode_query_len == 0 for size in original_sizes): + enlarged_sizes = [ + size * uniform_decode_query_len for size in original_sizes + if max_num_tokens >= size >= uniform_decode_query_len + ] + compilation_config.init_with_cudagraph_sizes(enlarged_sizes) + logger.info("Adjusted ACL graphs: %s → %s for speculative decoding", + original_sizes, enlarged_sizes) + elif original_sizes[0] < max_num_tokens: + enlarged_sizes = [ + size * uniform_decode_query_len for size in original_sizes + ] + compilation_config.init_with_cudagraph_sizes(enlarged_sizes) + logger.info("Adjusted ACL graphs: %s → %s for speculative decoding", + original_sizes, enlarged_sizes) + else: + compilation_config.cudagraph_capture_sizes = original_sizes # TODO(wxy): Move to ops module diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0a4fd85..287f16d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3529,14 +3529,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \ aclgraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len - decode_cudagraph_batch_sizes = [ - x for x in self.aclgraph_batch_sizes if x <= max_num_tokens - and x >= self.uniform_decode_query_len - ] compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + reversed(self.aclgraph_batch_sizes)) self._capture_aclgraphs( compilation_cases=compilation_cases_decode, aclgraph_runtime_mode=CUDAGraphMode.FULL,