bugfix for mtp in fullgraph (#3878)

### What this PR does / why we need it?
Bug fix for MTP (multi-token prediction) speculative decoding in fullgraph mode: ACL graph capture sizes are now adjusted for `num_speculative_tokens > 1` inside `update_aclgraph_sizes` (with a dedicated path for `FULL_DECODE_ONLY`), instead of being re-filtered in the model runner.

### Does this PR introduce _any_ user-facing change?
no

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
Author: zouyida2052
Date: 2025-10-29 23:52:20 +08:00 (committed by GitHub)
Parent: 19f49ecb5f
Commit: d9249c968e
6 changed files with 58 additions and 38 deletions


```diff
@@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev.
 Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible.
 ### Maintenance branch and EOL
 The table below lists branch states.
 | Branch | Time Frame | Summary |
 | ----------------- | -------------------------------- | --------------------------------------------------------- |
```


```diff
@@ -92,7 +92,7 @@ As the maximum number of tokens that can be scheduled is 10, the scheduled token
 ##### 1. Get token positions:
 First, determine which request each token belongs to: tokens 0-2 are assigned to **request_0**, tokens 3-4 to **request_1**, and tokens 5-9 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.
 For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1, ..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
 Note: there is a more efficient way (using `request indices`) to create positions in actual code.
```
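The position construction described here, including the more efficient `request indices` route the note alludes to, can be sketched in a few lines of plain Python using the example values from the text (a minimal sketch, not the actual vLLM implementation):

```python
from itertools import accumulate

# Example values from the walkthrough above.
request_indices = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]
num_computed_tokens = [0, 0, 0]   # prefill step: nothing computed yet
scheduled = [3, 2, 5]             # scheduled token count per request

# `query start location` is the prefix sum of the scheduled token counts.
query_start_loc = [0] + list(accumulate(scheduled))   # [0, 3, 5, 10]

# Position of token i = its offset within its own request, shifted by
# that request's already-computed token count.
positions = [
    i - query_start_loc[req] + num_computed_tokens[req]
    for i, req in enumerate(request_indices)
]
print(positions)   # [0, 1, 2, 0, 1, 0, 1, 2, 3, 4]
```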
```diff
@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
 Let's say `K = max model len / block size = 6`, and we can get token `device block number`.
 The workflow of achieving slot mapping:
 1. Get `block table indices` using `K`, `positions` and `request indices`.
    Purpose: for each token, it can be used to select the `device block number` from the `block table`.
 2. Get `device block number` using `block table indices`.
    Purpose: `device block number` indicates which device block each token belongs to.
 3. Get `block offsets` using `positions` and `block size`.
    Purpose: `block offsets` indicates the offset of each token within a block.
 4. Construct `slot mapping` using `device block number` and `block offsets`.
    Purpose: we can use `slot mapping` to store Token IDs into token slots.
 Details:
 1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equals `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This can be used to select `device block number` from `block table`.
 2. (**Token level**) Use `block table indices` to select the `device block number` for each scheduled token. The pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number = [1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`.
 3. (**Token level**) `block offsets` can be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
 4. Finally, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`.
 (**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
 - (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
 - (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
 - (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
 - `number of requests`: `3`
 - (**Request level**) `number of tokens`: `[3, 2, 5]`
 - `max query len`: `5`
```
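The four token-level steps above can be verified with a short standalone Python sketch. The flattened `block table` contents here are an assumption chosen to match the example (request_0 owns device blocks 1 and 2, request_1 owns block 3, request_2 owns blocks 4, 5 and 6):

```python
# Worked example of the four slot-mapping steps, with block_size = 2 and
# K = max model len / block size = 6.
block_size = 2
K = 6
request_indices = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]
positions = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4]
# Hypothetical flattened block table: request r owns rows [r*K, (r+1)*K).
block_table = [1, 2, 0, 0, 0, 0,
               3, 0, 0, 0, 0, 0,
               4, 5, 6, 0, 0, 0]

# 1. block table indices = request indices * K + positions // block size
block_table_indices = [r * K + p // block_size
                       for r, p in zip(request_indices, positions)]
# 2. device block number = block_table[block table indices]
device_block_numbers = [block_table[i] for i in block_table_indices]
# 3. block offsets = positions % block size
block_offsets = [p % block_size for p in positions]
# 4. slot mapping = device block number * block size + block offsets
slot_mapping = [n * block_size + o
                for n, o in zip(device_block_numbers, block_offsets)]
print(slot_mapping)   # [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]
```

The intermediate lists reproduce the values in the text: `block_table_indices` is `[0, 0, 1, 6, 6, 12, 12, 13, 13, 14]` and `device_block_numbers` is `[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`.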
```diff
@@ -235,7 +235,7 @@ KV cache block in the device memory:
 1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
 2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
 3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
 4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
 Scheduled token count: `[1, 1, 3]`
 - `query start location`: `[0, 1, 2, 5]`
```
```diff
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
 - `slot mapping`: `[5, 14, 13, 16, 17]`
 - `attention mask`: `5 * 8`
   Each token has a `1 * 8` vector, and there are 5 scheduled tokens.
```


```diff
@@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
 | GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
 | InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
 | Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
 | Ultravox | 🟡 | Need test |
```


```diff
@@ -263,6 +263,7 @@ class NPUPlatform(Platform):
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
+            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
```


```diff
@@ -314,6 +314,13 @@ def get_max_hidden_layers(hf_config) -> int:
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
+    from vllm.config.compilation import CUDAGraphMode
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        if vllm_config.speculative_config is not None and \
+                vllm_config.speculative_config.num_speculative_tokens > 1:
+            _update_spec_aclgraph_sizes(vllm_config)
+        return
+
     # NOTE: Currently, we can only capture 1800 graphs at most,
     # due to the limitation of ACL graph. This number is bounded by
     # the number of streams, which is 2048, we save 248 streams
```
```diff
@@ -421,25 +428,43 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
+    if vllm_config.speculative_config is not None and \
+            vllm_config.speculative_config.num_speculative_tokens > 1:
+        _update_spec_aclgraph_sizes(vllm_config)
+
+
+def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
     # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
     # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-        original_sizes, compilation_config.cudagraph_capture_sizes = \
-            compilation_config.cudagraph_capture_sizes, None
-        assert len(original_sizes) > 0
-        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
-            enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
-            logger.info(
-                "Adjusted ACL graphs: %s%s for speculative decoding",
-                original_sizes, enlarged_sizes)
-        else:
-            compilation_config.cudagraph_capture_sizes = original_sizes
+    from vllm.config.compilation import CUDAGraphMode
+    compilation_config = vllm_config.compilation_config
+    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+    uniform_decode_query_len = num_speculative_tokens + 1
+    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+    max_num_tokens = max_num_seqs * uniform_decode_query_len
+    original_sizes, compilation_config.cudagraph_capture_sizes = \
+        compilation_config.cudagraph_capture_sizes, None
+    assert len(original_sizes) > 0
+
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
+            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+            if max_num_tokens >= size >= uniform_decode_query_len
+        ]
+        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    elif original_sizes[0] < max_num_tokens:
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+        ]
+        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    else:
+        compilation_config.cudagraph_capture_sizes = original_sizes
 
 
 # TODO(wxy): Move to ops module
```
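For intuition, the capture-size adjustment this hunk introduces can be condensed into a standalone sketch. `adjust_capture_sizes` is a hypothetical helper, not part of the codebase; sizes are listed largest-first, matching the descending order vLLM keeps `cudagraph_capture_sizes` in:

```python
def adjust_capture_sizes(sizes, num_speculative_tokens, max_num_seqs,
                         full_decode_only):
    """Sketch of the spec-decode capture-size adjustment in this commit."""
    q = num_speculative_tokens + 1        # uniform decode query length
    max_num_tokens = max_num_seqs * q
    if full_decode_only and not all(s % q == 0 for s in sizes):
        # FULL_DECODE_ONLY: keep only sizes that can hold a uniform decode
        # batch, then scale each so graphs match multiples of the query len.
        return [s * q for s in sizes if q <= s <= max_num_tokens]
    if sizes[0] < max_num_tokens:
        # Largest size cannot cover max_num_seqs decode requests: scale all,
        # otherwise the draft model would fall back to eager mode.
        return [s * q for s in sizes]
    return sizes

# num_speculative_tokens=2 (MTP), max_num_seqs=4, FULL_DECODE_ONLY mode:
print(adjust_capture_sizes([8, 4, 2, 1], 2, 4, full_decode_only=True))
# [24, 12]
```

Note that the `FULL_DECODE_ONLY` branch filters before scaling, mirroring the list comprehension in the diff.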


```diff
@@ -3529,14 +3529,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 aclgraph_mode.separate_routine():
-            max_num_tokens = self.scheduler_config.max_num_seqs * \
-                self.uniform_decode_query_len
-            decode_cudagraph_batch_sizes = [
-                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
-                and x >= self.uniform_decode_query_len
-            ]
             compilation_cases_decode = list(
-                reversed(decode_cudagraph_batch_sizes))
+                reversed(self.aclgraph_batch_sizes))
             self._capture_aclgraphs(
                 compilation_cases=compilation_cases_decode,
                 aclgraph_runtime_mode=CUDAGraphMode.FULL,
```