bugfix for mtp in fullgraph (#3878)

### What this PR does / why we need it?
bugfix for mtp in fullgraph

### Does this PR introduce _any_ user-facing change?
no

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-10-29 23:52:20 +08:00
committed by GitHub
parent 19f49ecb5f
commit d9249c968e
6 changed files with 58 additions and 38 deletions

View File

@@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev.
Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible.
### Maintenance branch and EOL
The table below lists branch states.
The table below lists branch states.
| Branch | Time Frame | Summary |
| ----------------- | -------------------------------- | --------------------------------------------------------- |

View File

@@ -92,7 +92,7 @@ As the maximum number of tokens that can be schedules is 10, the scheduled token
##### 1. Get token positions:
First, determine which request each token belongs to: tokens 02 are assigned to **request_0**, tokens 34 to **request_1**, and tokens 59 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
Note: there is more efficient way (using `request indices`) to create positions in actual code.
@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
Let's say `K = max model len / block size = 6`, and we can get token `device block number`.
The workflow of achieving slot mapping:
1. Get `block table indices` using `K`, `positions` and `request indices`.
1. Get `block table indices` using `K`, `positions` and `request indices`.
Purpose: For each token, it could be used to select `device block number` from `block table`.
2. Get `device block number` using `block table indices`.
2. Get `device block number` using `block table indices`.
Purpose: `device block number` indicates which device block each token belongs to.
3. Get `block offsets` using `positions` and `block size`.
3. Get `block offsets` using `positions` and `block size`.
Purpose: `block offsets` indicates the offsets of each token within a block.
4. construct `slot mapping` using `device block number` and `block offsets`.
4. construct `slot mapping` using `device block number` and `block offsets`.
Purpose: we can use `slot mapping` to store Token IDs into token slots.
Details:
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
2. (**Token level**) Use `block table indices` to select out `device block number` for each scheduled token. The Pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number=[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
4. At last, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block_offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
- `number of requests`: `3`
- (**Request level**) `number of tokens`: `[3, 2, 5]`
- `max query len`: `5`
@@ -235,7 +235,7 @@ KV cache block in the device memory:
1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
Scheduled token count:`[1, 1, 3]`
- `query start location`: `[0, 1, 2, 5]`
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
- `slot mapping`: `[5, 14, 13, 16, 17]`
- `attention mask`: `5 * 8`
- `attention mask`: `5 * 8`
Each token has a `1 * 8` vector, and there are 5 scheduled tokens.

View File

@@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
| GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
| InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
| Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
| Ultravox | 🟡 | Need test |
| Ultravox | 🟡 | Need test |

View File

@@ -263,6 +263,7 @@ class NPUPlatform(Platform):
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
update_aclgraph_sizes(vllm_config)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",

View File

@@ -314,6 +314,13 @@ def get_max_hidden_layers(hf_config) -> int:
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
"""Update ACL graph capture sizes based on hardware limitations"""
from vllm.config.compilation import CUDAGraphMode
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)
return
# NOTE: Currently, we can only capture 1800 graphs at most,
# due to the limitation of ACL graph. This number is bounded by
# the number of streams, which is 2048, we save 248 streams
@@ -421,25 +428,43 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
_update_spec_aclgraph_sizes(vllm_config)
def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
# the maximum size cudagraph_capture_sizes[0] should be greater or equal than
# (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
if vllm_config.speculative_config is not None and \
vllm_config.speculative_config.num_speculative_tokens > 1:
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
enlarged_sizes = [(num_speculative_tokens + 1) * size
for size in original_sizes]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info(
"Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
else:
compilation_config.cudagraph_capture_sizes = original_sizes
from vllm.config.compilation import CUDAGraphMode
compilation_config = vllm_config.compilation_config
num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
uniform_decode_query_len = num_speculative_tokens + 1
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_tokens = max_num_seqs * uniform_decode_query_len
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
not all(size % uniform_decode_query_len == 0 for size in original_sizes):
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
if max_num_tokens >= size >= uniform_decode_query_len
]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
elif original_sizes[0] < max_num_tokens:
enlarged_sizes = [
size * uniform_decode_query_len for size in original_sizes
]
compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
logger.info("Adjusted ACL graphs: %s%s for speculative decoding",
original_sizes, enlarged_sizes)
else:
compilation_config.cudagraph_capture_sizes = original_sizes
# TODO(wxy): Move to ops module

View File

@@ -3529,14 +3529,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
aclgraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
reversed(self.aclgraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,