bugfix for mtp in fullgraph (#3878)

### What this PR does / why we need it?
bugfix for mtp in fullgraph

### Does this PR introduce _any_ user-facing change?
no

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-10-29 23:52:20 +08:00
committed by GitHub
parent 19f49ecb5f
commit d9249c968e
6 changed files with 58 additions and 38 deletions

View File

@@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev.
Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible.
### Maintenance branch and EOL
The table below lists branch states.
The table below lists branch states.
| Branch | Time Frame | Summary |
| ----------------- | -------------------------------- | --------------------------------------------------------- |

View File

@@ -92,7 +92,7 @@ As the maximum number of tokens that can be schedules is 10, the scheduled token
##### 1. Get token positions:
First, determine which request each token belongs to: tokens 02 are assigned to **request_0**, tokens 34 to **request_1**, and tokens 59 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
Note: there is more efficient way (using `request indices`) to create positions in actual code.
@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
Let's say `K = max model len / block size = 6`, and we can get token `device block number`.
The workflow of achieving slot mapping:
1. Get `block table indices` using `K`, `positions` and `request indices`.
1. Get `block table indices` using `K`, `positions` and `request indices`.
Purpose: For each token, it could be used to select `device block number` from `block table`.
2. Get `device block number` using `block table indices`.
2. Get `device block number` using `block table indices`.
Purpose: `device block number` indicates which device block each token belongs to.
3. Get `block offsets` using `positions` and `block size`.
3. Get `block offsets` using `positions` and `block size`.
Purpose: `block offsets` indicates the offsets of each token within a block.
4. construct `slot mapping` using `device block number` and `block offsets`.
4. construct `slot mapping` using `device block number` and `block offsets`.
Purpose: we can use `slot mapping` to store Token IDs into token slots.
Details:
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
2. (**Token level**) Use `block table indices` to select out `device block number` for each scheduled token. The Pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number=[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
4. At last, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block_offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
- `number of requests`: `3`
- (**Request level**) `number of tokens`: `[3, 2, 5]`
- `max query len`: `5`
@@ -235,7 +235,7 @@ KV cache block in the device memory:
1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
Scheduled token count:`[1, 1, 3]`
- `query start location`: `[0, 1, 2, 5]`
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
- `slot mapping`: `[5, 14, 13, 16, 17]`
- `attention mask`: `5 * 8`
- `attention mask`: `5 * 8`
Each token has a `1 * 8` vector, and there are 5 scheduled tokens.

View File

@@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
| GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
| InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
| Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
| Ultravox | 🟡 | Need test |
| Ultravox | 🟡 | Need test |