bugfix for mtp in fullgraph (#3878)
### What this PR does / why we need it?
Fixes a bug for MTP in full graph mode.

### Does this PR introduce _any_ user-facing change?
No.

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
@@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev.
Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible.

### Maintenance branch and EOL

The table below lists branch states.

| Branch | Time Frame | Summary |
| ----------------- | -------------------------------- | --------------------------------------------------------- |
@@ -92,7 +92,7 @@ As the maximum number of tokens that can be scheduled is 10, the scheduled token
##### 1. Get token positions:

First, determine which request each token belongs to: tokens 0–2 are assigned to **request_0**, tokens 3–4 to **request_1**, and tokens 5–9 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.

For each request, use **the number of computed tokens** + **the relative position of the current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1, ..., 0 + 4]`) and then concatenate the results to get `positions` (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
Note: there is a more efficient way (using `request indices`) to create positions in the actual code.
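The two steps above can be sketched with NumPy (a hypothetical standalone sketch, not the actual vLLM code; the array values are taken from the example):

```python
import numpy as np

# Example inputs: scheduled token counts per request; in step 1 every
# request has 0 computed tokens.
num_scheduled_tokens = np.array([3, 2, 5])
num_computed_tokens = np.array([0, 0, 0])

# request indices: which request each scheduled token belongs to.
request_indices = np.repeat(np.arange(len(num_scheduled_tokens)),
                            num_scheduled_tokens)
# -> [0 0 0 1 1 2 2 2 2 2]

# positions: number of computed tokens + relative position of each
# scheduled token, built in a vectorized way via request_indices.
starts = np.cumsum(num_scheduled_tokens) - num_scheduled_tokens
token_arange = np.arange(num_scheduled_tokens.sum())
positions = (num_computed_tokens[request_indices]
             + token_arange - starts[request_indices])
# -> [0 1 2 0 1 0 1 2 3 4]
```

This mirrors the "more efficient way" the note alludes to: `request_indices` is reused to gather per-request offsets instead of looping over requests.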
@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
Let's say `K = max model len / block size = 6`; with it we can get each token's `device block number`.

The workflow for computing the slot mapping:

1. Get `block table indices` using `K`, `positions` and `request indices`.

   Purpose: for each token, this index is used to select its `device block number` from the `block table`.

2. Get `device block number` using `block table indices`.

   Purpose: `device block number` indicates which device block each token belongs to.

3. Get `block offsets` using `positions` and `block size`.

   Purpose: `block offsets` indicates the offset of each token within its block.

4. Construct `slot mapping` using `device block number` and `block offsets`.

   Purpose: we can use `slot mapping` to store Token IDs into token slots.

Details:

1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size` (integer division). So it equals `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.

2. (**Token level**) Use `block table indices` to select out the `device block number` for each scheduled token. The pseudocode is `block_numbers = block_table[block_table_indices]`, so `device block number = [1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`.

3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
4. (**Token level**) Finally, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`.
(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:

- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
- `number of requests`: `3`
- (**Request level**) `number of tokens`: `[3, 2, 5]`
- `max query len`: `5`
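The four token-level details above can be checked end to end with a short NumPy sketch (hypothetical standalone code, not the actual implementation; the flattened `block table` entries are filled in from the example, with unused slots left at 0):

```python
import numpy as np

block_size = 2
K = 6  # max model len / block size

request_indices = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
positions = np.array([0, 1, 2, 0, 1, 0, 1, 2, 3, 4])

# Flattened block table (3 requests x K entries); the device block
# numbers come from the example, unused slots stay 0.
block_table = np.zeros(3 * K, dtype=np.int64)
block_table[[0, 1]] = [1, 2]           # request_0
block_table[6] = 3                     # request_1
block_table[[12, 13, 14]] = [4, 5, 6]  # request_2

# Steps 1-4 from the details above.
block_table_indices = request_indices * K + positions // block_size
block_numbers = block_table[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
```

Running this reproduces every intermediate list in the worked example.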
@@ -235,7 +235,7 @@ KV cache block in the device memory:
1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
Scheduled token count: `[1, 1, 3]`

- `query start location`: `[0, 1, 2, 5]`
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
- `slot mapping`: `[5, 14, 13, 16, 17]`

- `attention mask`: `5 * 8`

Each token has a `1 * 8` vector, and there are 5 scheduled tokens.
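The step-2 numbers can be reproduced the same way (a standalone sketch with values copied from the lists above):

```python
import numpy as np

block_size = 2

# Request level: prefix sum of the scheduled token count gives
# query start location.
scheduled = np.array([1, 1, 3])
query_start_loc = np.concatenate(([0], np.cumsum(scheduled)))
# -> [0 1 2 5]

# Token level: slot mapping from device block number and block offsets.
device_block_number = np.array([2, 7, 6, 8, 8])
block_offsets = np.array([1, 0, 1, 0, 1])
slot_mapping = device_block_number * block_size + block_offsets
# -> [5 14 13 16 17]
```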
@@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
| GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
| InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
| Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
| Ultravox | 🟡 | Need test |
@@ -263,6 +263,7 @@ class NPUPlatform(Platform):
**********************************************************************************\033[0m
"""
        logger.warning(warning_message)
        update_aclgraph_sizes(vllm_config)
    else:
        logger.info(
            "%s cudagraph_mode is not supported on NPU. Falling back to NONE",
@@ -314,6 +314,13 @@ def get_max_hidden_layers(hf_config) -> int:
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    """Update ACL graph capture sizes based on hardware limitations"""
    from vllm.config.compilation import CUDAGraphMode
    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
        if vllm_config.speculative_config is not None and \
                vllm_config.speculative_config.num_speculative_tokens > 1:
            _update_spec_aclgraph_sizes(vllm_config)
        return

    # NOTE: Currently, we can only capture 1800 graphs at most,
    # due to the limitation of ACL graph. This number is bounded by
    # the number of streams, which is 2048; we save 248 streams
@@ -421,25 +428,43 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
        vllm_config.model_config.architectures[0], num_hidden_layers,
        len(original_sizes))

    if vllm_config.speculative_config is not None and \
            vllm_config.speculative_config.num_speculative_tokens > 1:
        _update_spec_aclgraph_sizes(vllm_config)

def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    # The default or user-defined cudagraph_capture_sizes may not account for
    # the num_speculative_tokens > 1 scenario: the maximum size,
    # cudagraph_capture_sizes[0], should be greater than or equal to
    # (num_speculative_tokens + 1) * max_num_seqs, otherwise the draft model
    # will run in eager mode.
    if vllm_config.speculative_config is not None and \
            vllm_config.speculative_config.num_speculative_tokens > 1:
        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        original_sizes, compilation_config.cudagraph_capture_sizes = \
            compilation_config.cudagraph_capture_sizes, None
        assert len(original_sizes) > 0
        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
            enlarged_sizes = [(num_speculative_tokens + 1) * size
                              for size in original_sizes]
            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
            logger.info(
                "Adjusted ACL graphs: %s → %s for speculative decoding",
                original_sizes, enlarged_sizes)
        else:
            compilation_config.cudagraph_capture_sizes = original_sizes
    from vllm.config.compilation import CUDAGraphMode
    compilation_config = vllm_config.compilation_config
    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
    uniform_decode_query_len = num_speculative_tokens + 1
    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    original_sizes, compilation_config.cudagraph_capture_sizes = \
        compilation_config.cudagraph_capture_sizes, None
    assert len(original_sizes) > 0

    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
        enlarged_sizes = [
            size * uniform_decode_query_len for size in original_sizes
            if max_num_tokens >= size >= uniform_decode_query_len
        ]
        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
                    original_sizes, enlarged_sizes)
    elif original_sizes[0] < max_num_tokens:
        enlarged_sizes = [
            size * uniform_decode_query_len for size in original_sizes
        ]
        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
                    original_sizes, enlarged_sizes)
    else:
        compilation_config.cudagraph_capture_sizes = original_sizes

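The sizing rules above can be illustrated with a standalone sketch (a hypothetical helper with made-up numbers; it mirrors the branch logic without the vLLM config objects, and assumes the sizes list is sorted in descending order as `cudagraph_capture_sizes` is):

```python
def adjust_capture_sizes(original_sizes, num_speculative_tokens, max_num_seqs,
                         full_decode_only=True):
    """Sketch of the capture-size adjustment; not the real vLLM helper."""
    uniform_decode_query_len = num_speculative_tokens + 1
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    if full_decode_only and not all(
            s % uniform_decode_query_len == 0 for s in original_sizes):
        # Keep only sizes in [uniform_decode_query_len, max_num_tokens]
        # and scale them so each batch holds whole decode queries.
        return [s * uniform_decode_query_len for s in original_sizes
                if max_num_tokens >= s >= uniform_decode_query_len]
    if original_sizes[0] < max_num_tokens:
        # Scale every size by the decode query length.
        return [s * uniform_decode_query_len for s in original_sizes]
    return original_sizes


# e.g. 2 speculative tokens, 64 sequences: sizes are filtered and tripled
print(adjust_capture_sizes([512, 256, 8, 4, 2, 1], 2, 64))  # -> [24, 12]
```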
# TODO(wxy): Move to ops module
@@ -3529,14 +3529,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                aclgraph_mode.separate_routine():
            max_num_tokens = self.scheduler_config.max_num_seqs * \
                self.uniform_decode_query_len
            decode_cudagraph_batch_sizes = [
                x for x in self.aclgraph_batch_sizes
                if x <= max_num_tokens and x >= self.uniform_decode_query_len
            ]
            compilation_cases_decode = list(
                reversed(decode_cudagraph_batch_sizes))
            self._capture_aclgraphs(
                compilation_cases=compilation_cases_decode,
                aclgraph_runtime_mode=CUDAGraphMode.FULL,
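The decode-capture filtering above can be sketched in isolation (a hypothetical helper, not the runner method; `aclgraph_batch_sizes` is assumed to be ascending, so the `reversed(...)` call yields largest-first capture cases):

```python
def decode_capture_sizes(aclgraph_batch_sizes, max_num_seqs,
                         uniform_decode_query_len):
    """Sketch of how decode compilation cases are selected above."""
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    # Only batch sizes that fit a whole uniform decode step qualify.
    sizes = [x for x in aclgraph_batch_sizes
             if uniform_decode_query_len <= x <= max_num_tokens]
    # Mirrors the reversed(...) call: capture larger cases first.
    return list(reversed(sizes))


# e.g. num_speculative_tokens = 1 -> uniform_decode_query_len = 2
print(decode_capture_sizes([1, 2, 4, 8, 16, 32], 8, 2))  # -> [16, 8, 4, 2]
```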