Multi Token Prediction (MTP)
Why We Need MTP
MTP boosts inference performance by parallelizing the prediction of multiple tokens, shifting from single-token to multi-token generation. This approach significantly increases generation throughput and achieves multiplicative acceleration in inference speed—all without compromising output quality.
How to Use MTP
To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service:
--speculative_config '{"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": false}'
- num_speculative_tokens: The number of speculative tokens, which enables the model to predict multiple tokens at once. If not provided, it defaults to the value in the draft model config if present; otherwise it is required.
- disable_padded_drafter_batch: Disables input padding for speculative decoding. If set to True, speculative input batches may contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation; the default is False.
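Note that the string passed to --speculative_config must be valid JSON (e.g. false, not Python's False). As a quick sanity check, the hypothetical helper below (not part of vllm-ascend) builds such a string and enforces the 15-token operator limit described later in this document:

```python
import json

# Hypothetical helper: build the JSON string passed via --speculative_config.
def build_mtp_spec_config(num_speculative_tokens=1,
                          disable_padded_drafter_batch=False):
    # Current operator limitations cap MTP at 15 speculative tokens.
    if not 1 <= num_speculative_tokens <= 15:
        raise ValueError("MTP supports at most 15 speculative tokens")
    return json.dumps({
        "method": "mtp",
        "num_speculative_tokens": num_speculative_tokens,
        "disable_padded_drafter_batch": disable_padded_drafter_batch,
    })

print(build_mtp_spec_config())
# {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": false}
```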
How It Works
Module Architecture
vllm_ascend
├── sample
│   └── rejection_sample.py
└── spec_decode
    └── mtp_proposer.py
1. sample
- rejection_sample.py: During decoding, the main model processes the previous round's output token and the predicted token together (computing 1+k tokens simultaneously). The first token is always correct, while the second token, referred to as the bonus token, is uncertain since it is derived from speculative prediction. We therefore employ a greedy strategy and a rejection sampling strategy to determine whether the bonus token should be accepted. The module consists of an AscendRejectionSampler class with a forward method that implements the specific sampling logic.
rejection_sample.py
└── AscendRejectionSampler
    └── forward
2. spec_decode
This module covers the model-side preprocessing for spec-decode: loading the model, executing a dummy run, and generating token ids. Together, these steps form the model data construction and forward invocation for a single spec-decode operation.
- mtp_proposer.py: Configures vLLM-Ascend to use speculative decoding, with proposals generated by the DeepSeek MTP layer.
mtp_proposer.py
└── Proposer
    ├── load_model
    ├── dummy_run
    ├── generate_token_ids
    ├── _prepare_inputs
    └── _propose
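To make the lifecycle above concrete, here is a toy stand-in for the proposer. It assumes nothing about the real MtpProposer internals: the method names mirror the steps listed above, but the draft "model" is a trivial stub that increments token ids, whereas the real proposer runs the DeepSeek MTP layer.

```python
class ToyProposer:
    # Toy stand-in for the proposer lifecycle; illustrative only.
    def __init__(self, num_speculative_tokens):
        self.k = num_speculative_tokens
        self.model = None

    def load_model(self):
        # Stand-in for loading the draft (MTP) weights.
        self.model = lambda token_id: token_id + 1

    def dummy_run(self):
        # Warm-up forward pass with a placeholder input.
        assert self.model is not None, "call load_model() first"
        self.model(0)

    def generate_token_ids(self, last_token_id):
        # Autoregressively propose num_speculative_tokens draft tokens.
        proposed, token = [], last_token_id
        for _ in range(self.k):
            token = self.model(token)
            proposed.append(token)
        return proposed
```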
Algorithm
1. Rejection Sampling
- Greedy Strategy
Verify whether the token generated by the main model matches the speculative token predicted by MTP in the previous round. If they match exactly, accept the bonus token; otherwise, reject it and any subsequent tokens derived from that speculation.
- Rejection Sampling Strategy
This method introduces stochasticity in rejection sampling.
For each draft token, acceptance is determined by verifying whether the inequality P_target / P_draft ≥ U holds, where P_target represents the probability assigned to the current draft token by the target model, P_draft denotes the probability assigned by the draft model, and U is a random number sampled uniformly from the interval [0, 1).
The decision logic for each draft token is as follows: if the inequality P_target / P_draft ≥ U holds, the draft token is accepted as output; conversely, if P_target / P_draft < U, the draft token is rejected.
When a draft token is rejected, a recovery sampling process is triggered where a "recovered token" is resampled from the adjusted probability distribution defined as Q = max(P_target - P_draft, 0). In the current MTP implementation, since P_draft is not provided and defaults to 1, the formulas simplify such that token acceptance occurs when P_target ≥ U, and the recovery distribution becomes Q = max(P_target - 1, 0).
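The two strategies can be sketched in plain Python. This is a simplified per-token illustration under the definitions above, not the vectorized AscendRejectionSampler code:

```python
import random

def greedy_verify(draft_tokens, target_tokens):
    # Greedy strategy: accept draft tokens while they match the target
    # model's greedy output. target_tokens holds len(draft_tokens) + 1
    # entries; the extra one is the bonus token, emitted only when every
    # draft token matches. On the first mismatch, the target's token is
    # emitted instead and the remaining drafts are discarded.
    accepted = []
    for draft, target in zip(draft_tokens, target_tokens):
        if draft == target:
            accepted.append(draft)
        else:
            accepted.append(target)  # recovered token from the target model
            return accepted
    accepted.append(target_tokens[len(draft_tokens)])  # bonus token
    return accepted

def verify_draft_token(p_target, p_draft, token, rng=random.random):
    # Rejection sampling strategy: accept the draft token when
    # P_target / P_draft >= U, with U ~ Uniform[0, 1).
    u = rng()
    if p_draft[token] > 0 and p_target[token] / p_draft[token] >= u:
        return token, True
    # Rejected: resample a recovered token from Q = max(P_target - P_draft, 0).
    q = [max(pt - pd, 0.0) for pt, pd in zip(p_target, p_draft)]
    if sum(q) == 0.0:  # degenerate case: fall back to the target distribution
        q = list(p_target)
    r = rng() * sum(q)  # inverse-CDF sampling over the unnormalized Q
    cumulative = 0.0
    for i, qi in enumerate(q):
        cumulative += qi
        if r < cumulative:
            return i, False
    return len(q) - 1, False
```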
2. Performance
If the bonus token is accepted, the MTP model performs inference for (num_speculative_tokens + 1) tokens, including the original main model output token and the bonus token. If it is rejected, inference covers fewer tokens, depending on how many speculative tokens were accepted.
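As a back-of-the-envelope model (an illustrative assumption, not a measured vllm-ascend result), suppose each draft token is accepted independently with probability alpha and verification stops at the first rejection. Then the expected number of tokens emitted per step is a geometric series:

```python
def expected_tokens_per_step(num_speculative_tokens, alpha):
    # One token (recovered or bonus) is always emitted, plus the accepted
    # prefix of drafts: E = sum_{i=0}^{k} alpha^i = (1 - alpha^(k+1)) / (1 - alpha).
    k = num_speculative_tokens
    if alpha == 1.0:
        return float(k + 1)  # every draft accepted: k drafts + bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)
```

For example, with one speculative token and a 50% acceptance rate, each step emits 1.5 tokens on average instead of 1.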
DFX
Method Validation
- Currently, the spec_decode scenario only supports the ngram, eagle, eagle3, and mtp methods. If an unsupported method is passed, the code raises an error to alert the user that an incorrect method was provided.
def get_spec_decode_method(method, vllm_config, device, runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return EagleProposer(vllm_config, device, runner)
    elif method == 'mtp':
        return MtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")
Integer Validation
- The current npu_fused_infer_attention_score operator only supports fewer than 16 speculative tokens per decode round, so the maximum supported value for MTP is 15. If a value greater than 15 is provided, the code raises an error to alert the user.
if self.speculative_config:
    spec_token_num = self.speculative_config.num_speculative_tokens
    self.decode_threshold += spec_token_num
    assert self.decode_threshold <= 16, f"decode_threshold exceeded \
        npu_fused_infer_attention_score TND layout's limit of 16, \
        got {self.decode_threshold}"
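A standalone version of that check (a hypothetical helper mirroring the assert above, not vllm-ascend code) might look like:

```python
def check_decode_threshold(num_speculative_tokens, base_decode_threshold=1):
    # npu_fused_infer_attention_score (TND layout) handles at most 16 tokens
    # per decode round, which caps MTP at 15 speculative tokens.
    decode_threshold = base_decode_threshold + num_speculative_tokens
    if decode_threshold > 16:
        raise ValueError(
            "decode_threshold exceeded npu_fused_infer_attention_score "
            f"TND layout's limit of 16, got {decode_threshold}")
    return decode_threshold
```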
Limitation
- Because only a single layer of weights is exposed in DeepSeek's MTP, accuracy and performance are not effectively guaranteed in scenarios where MTP > 1 (especially MTP ≥ 3). Moreover, due to current operator limitations, MTP supports a maximum of 15 speculative tokens.
- In fullgraph mode with MTP > 1, the capture size of each aclgraph must be an integer multiple of (num_speculative_tokens + 1).
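For example, with num_speculative_tokens = 2, every aclgraph capture size must be a multiple of 3. The illustrative helper below (not part of vllm-ascend) filters a candidate list accordingly:

```python
def valid_capture_sizes(candidate_sizes, num_speculative_tokens):
    # In fullgraph mode, each aclgraph capture size must be an integer
    # multiple of (num_speculative_tokens + 1).
    step = num_speculative_tokens + 1
    return [s for s in candidate_sizes if s % step == 0]
```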