bugfix for mtp in fullgraph (#3878)
### What this PR does / why we need it? bugfix for mtp in fullgraph ### Does this PR introduce _any_ user-facing change? no --------- Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -92,7 +92,7 @@ As the maximum number of tokens that can be schedules is 10, the scheduled token
|
||||
##### 1. Get token positions:
|
||||
First, determine which request each token belongs to: tokens 0–2 are assigned to **request_0**, tokens 3–4 to **request_1**, and tokens 5–9 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.
|
||||
|
||||
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
|
||||
For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
|
||||
|
||||
Note: there is more efficient way (using `request indices`) to create positions in actual code.
|
||||
|
||||
@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
|
||||
Let's say `K = max model len / block size = 6`, and we can get token `device block number`.
|
||||
|
||||
The workflow of achieving slot mapping:
|
||||
1. Get `block table indices` using `K`, `positions` and `request indices`.
|
||||
1. Get `block table indices` using `K`, `positions` and `request indices`.
|
||||
|
||||
Purpose: For each token, it could be used to select `device block number` from `block table`.
|
||||
|
||||
2. Get `device block number` using `block table indices`.
|
||||
2. Get `device block number` using `block table indices`.
|
||||
|
||||
Purpose: `device block number` indicates which device block each token belongs to.
|
||||
|
||||
3. Get `block offsets` using `positions` and `block size`.
|
||||
3. Get `block offsets` using `positions` and `block size`.
|
||||
|
||||
Purpose: `block offsets` indicates the offsets of each token within a block.
|
||||
|
||||
4. construct `slot mapping` using `device block number` and `block offsets`.
|
||||
4. construct `slot mapping` using `device block number` and `block offsets`.
|
||||
|
||||
Purpose: we can use `slot mapping` to store Token IDs into token slots.
|
||||
|
||||
Details:
|
||||
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
|
||||
1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
|
||||
2. (**Token level**) Use `block table indices` to select out `device block number` for each scheduled token. The Pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number=[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`
|
||||
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
|
||||
3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
|
||||
4. At last, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block_offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`
|
||||
|
||||
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
|
||||
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
|
||||
|
||||
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
|
||||
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
|
||||
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
|
||||
- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
|
||||
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
|
||||
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
|
||||
- `number of requests`: `3`
|
||||
- (**Request level**) `number of tokens`: `[3, 2, 5]`
|
||||
- `max query len`: `5`
|
||||
@@ -235,7 +235,7 @@ KV cache block in the device memory:
|
||||
1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
|
||||
2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
|
||||
3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
|
||||
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
|
||||
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
|
||||
|
||||
Scheduled token count:`[1, 1, 3]`
|
||||
- `query start location`: `[0, 1, 2, 5]`
|
||||
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
|
||||
|
||||
- `slot mapping`: `[5, 14, 13, 16, 17]`
|
||||
|
||||
- `attention mask`: `5 * 8`
|
||||
- `attention mask`: `5 * 8`
|
||||
|
||||
Each token has a `1 * 8` vector, and there are 5 scheduled tokens.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user