[FEAT] Refactor spec decode to support efficient padded speculation (#3528)
### What this PR does / why we need it?
1. Refactors `mtp_proposer.py`, splitting the torchair-related code into
`torchair_mtp_proposer.py`.
2. Following https://github.com/vllm-project/vllm/pull/24539,
implements padded speculative decoding as described in
https://github.com/vllm-project/vllm/issues/21984 (a conceptual sketch follows below).
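For context, a minimal conceptual sketch of the difference between the unpadded and padded drafter input (illustration only, not this PR's code; the helper names are made up):

```
import torch

# Unpadded path: accepted tokens are gathered on the CPU first, so the
# drafter batch is ragged and its size changes from step to step.
def drafter_inputs_unpadded(valid_sampled_token_ids: list[list[int]]) -> torch.Tensor:
    last_tokens = [ids[-1] for ids in valid_sampled_token_ids if ids]
    return torch.tensor(last_tokens, dtype=torch.int64)

# Padded path: keep one slot per running request and feed the sampler's GPU
# tensor straight to the drafter; invalid slots are masked out afterwards,
# so the draft step does not need to wait for CPU-side bookkeeping.
def drafter_inputs_padded(sampled_token_ids: torch.Tensor) -> torch.Tensor:
    return sampled_token_ids[:, -1]
```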
### Does this PR introduce _any_ user-facing change?
Users can set `disable_padded_drafter_batch` in `speculative_config` to
enable or disable padded speculation; the default is `False`, meaning padded
speculation is on by default.
Offline example:
```
speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
```
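The same config expanded into a complete offline snippet (a sketch only; the model path, parallelism, and prompt below are placeholders, not part of this PR):

```
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",  # placeholder model path
    tensor_parallel_size=16,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
        # False (the default) keeps the drafter batch padded;
        # set it to True to fall back to the unpadded path.
        "disable_padded_drafter_batch": False,
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```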
### How was this patch tested?
- [x] eager with pad/unpad
- [x] aclgraph with pad/unpad
- [x] torchair with pad/unpad

Performance test of DeepSeek-R1 with tp16, dp1:
- aclgraph with pad, ITL: 168 ms
- aclgraph with unpad, ITL: 169 ms
- original: 178 ms
- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
```
@@ -133,6 +133,7 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               AscendSocVersion, ProfileExecuteDuration,
                               enable_sp, get_ascend_soc_version, is_310p,
```
```
@@ -369,32 +370,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.attn_mask_builder = AttentionMaskBuilder(
            self.model_config.max_model_len, self.dtype)

        # Set up speculative decoding.
        self.spec_attn_mask = None
        self.drafter: Optional[Union[NgramProposer, EagleProposer,
                                     MtpProposer]] = None
        self.actual_seq_lengths_q: list[int] = []
        self.decode_token_per_req = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            assert spec_token_num > 0
            self.decode_token_per_req = 1 + spec_token_num
            self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                        2048,
                                                        dtype=torch.bool),
                                             diagonal=1).to(self.device)
            if get_pp_group().is_last_rank:
                self.drafter = get_spec_decode_method(
                    self.speculative_config.method, self.vllm_config,
                    self.device, self)
                if vllm_version_is("0.11.0"):
                    self.rejection_sampler = AscendRejectionSampler()
                else:
                    self.rejection_sampler = AscendRejectionSampler(
                        self.sampler)
            self.actual_seq_lengths_q = list(
                range(self.decode_token_per_req, self.max_num_tokens + 1,
                      self.decode_token_per_req))
        self._set_up_drafter()

        # kv role
        self.is_kv_producer = False
```
```
@@ -590,6 +566,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
        self.is_multimodal_pruning_enabled = False

    def _set_up_drafter(self):
        # Set up speculative decoding.
        self.spec_attn_mask = None
        self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
                                     TorchairMtpProposer]] = None
        self.actual_seq_lengths_q: list[int] = []
        self.decode_token_per_req = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            assert spec_token_num > 0
            self.decode_token_per_req = 1 + spec_token_num
            self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                        2048,
                                                        dtype=torch.bool),
                                             diagonal=1).to(self.device)
            if get_pp_group().is_last_rank:
                self.drafter = self._get_drafter()
                if vllm_version_is("0.11.0"):
                    self.rejection_sampler = AscendRejectionSampler()
                else:
                    self.rejection_sampler = AscendRejectionSampler(
                        self.sampler)
            self.actual_seq_lengths_q = list(
                range(self.decode_token_per_req, self.max_num_tokens + 1,
                      self.decode_token_per_req))
        self.discard_request_indices = self._make_buffer(self.max_num_reqs,
                                                         dtype=torch.int64)
        self.num_discarded_requests = 0

    def _get_drafter(self):
        return get_spec_decode_method(self.speculative_config.method,
                                      self.vllm_config, self.device, self)

    def _may_pad_kv_consumer_num_seq(self):
        # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
        # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
```
```
@@ -609,7 +618,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        tp_size = self.parallel_config.tensor_parallel_size
        # Use integer arithmetic for ceiling division.
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
        self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size

    def _make_buffer(self,
                     *size: Union[int, torch.SymInt],
```
```
@@ -1522,6 +1531,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
        attn_metadata: dict[str, Any] = {}

        # Record the index of requests that should not be sampled,
        # so that we could clear the sampled tokens before returning
        num_tokens = [
            self.requests[r].num_tokens for r in self.input_batch.req_ids
        ]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)
        num_reqs = self.input_batch.num_reqs
        discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
        self.discard_request_indices.np[:self.num_discarded_requests] = (
            discard_request_indices)
        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)

        # _prepare_inputs may reorder the batch, so we must gather
        # multi-modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
```
```
@@ -1615,7 +1638,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
        spec_decode_common_attn_metadata = None
        self.spec_decode_common_attn_metadata = None
        if use_spec_decode and self.need_accepted_tokens:
            self.num_accepted_tokens.np[:num_reqs] = (
                self.input_batch.num_accepted_tokens_cpu[:num_reqs])
```
```
@@ -1676,7 +1699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[:num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
            seq_lens_cpu=self.seq_lens_cpu,
            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
            seq_lens=self.seq_lens_cpu[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=slot_mapping_size,
```
```
@@ -1700,8 +1723,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        )

        if self.speculative_config and \
                spec_decode_common_attn_metadata is None:
            spec_decode_common_attn_metadata = common_attn_metadata
                self.spec_decode_common_attn_metadata is None:
            self.spec_decode_common_attn_metadata = common_attn_metadata

        for attn_group in self.attn_groups[kv_cache_group_id]:
            common_prefix_len = 0
```
```
@@ -1998,7 +2021,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):

    def propose_draft_token_ids(
        self,
        valid_sampled_token_ids: list[list[int]],
        valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]],
        sampling_metadata: SamplingMetadata,
        scheduler_output: "SchedulerOutput",
        spec_decode_metadata: SpecDecodeMetadata,
```
```
@@ -2255,6 +2278,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            logits = self.apply_grammar_bitmask(
                scheduler_output, logits)

        with ProfileExecuteDuration().capture_async("Sample"):
            # Sample the next token and get logprobs if needed.
            sampling_metadata = self.input_batch.sampling_metadata
            if spec_decode_metadata is None:
```
```
@@ -2296,21 +2320,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            if self.need_accepted_tokens:
                self._update_states_after_model_execute(output_token_ids)

            discard_sampled_tokens_req_indices: list[int] = []
            # TODO(woosuk): The following loop can be slow since it iterates over
            # the requests one by one. Optimize.
            discard_sampled_tokens_req_indices = []
            for i, req_id in enumerate(self.input_batch.req_ids):
                req_state = self.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                if seq_len < req_state.num_tokens:
                    # Ignore the sampled token.
                    # Rewind the generator state as if the token was not sampled.
                    generator = self.input_batch.generators.get(i)
                    if generator is not None:
                        generator.set_offset(generator.get_offset() - 4)
                    discard_sampled_tokens_req_indices.append(i)
            discard_sampled_tokens_req_indices = \
                self.discard_request_indices.np[:self.num_discarded_requests]
            for i in discard_sampled_tokens_req_indices:
                generator = self.input_batch.generators.get(int(i))
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

            # Copy some objects so they don't get modified after returning.
            # This is important when using async scheduling.
```
```
@@ -2346,10 +2361,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            )
            # Mask out the sampled tokens that should not be sampled.
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()
                valid_sampled_token_ids[int(i)].clear()
        else:
            valid_sampled_token_ids = []
            invalid_req_indices = list(discard_sampled_tokens_req_indices)
            invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
            )
            invalid_req_indices_set = set(invalid_req_indices)
            assert sampled_token_ids.shape[-1] == 1
```
```
@@ -2394,18 +2410,33 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                req_state = self.requests[req_id]
                req_state.output_token_ids.extend(sampled_ids)

        def propose_draft_token_ids(sampled_token_ids):
            assert self.spec_decode_common_attn_metadata is not None
            self._draft_token_ids = self.propose_draft_token_ids(
                sampled_token_ids,
                sampling_metadata,
                scheduler_output,
                spec_decode_metadata,
                positions,
                scheduler_output.total_num_scheduled_tokens,
                hidden_states,
                attn_metadata,
                aux_hidden_states,
            )

        with ProfileExecuteDuration().capture_async("Draft"):
            if self.speculative_config:
                self._draft_token_ids = self.propose_draft_token_ids(
                    valid_sampled_token_ids,
                    sampling_metadata,
                    scheduler_output,
                    spec_decode_metadata,
                    positions,
                    scheduler_output.total_num_scheduled_tokens,
                    hidden_states,
                    attn_metadata,
                    aux_hidden_states,
                )
            use_padded_batch_for_eagle = self.speculative_config and \
                self.speculative_config.method == "deepseek_mtp" and \
                not self.speculative_config.disable_padded_drafter_batch
            if use_padded_batch_for_eagle:
                # EAGLE speculative decoding can use the GPU sampled tokens
                # as inputs, and does not need to wait for bookkeeping to finish.
                propose_draft_token_ids(sampler_output.sampled_token_ids)
            if self.speculative_config and not use_padded_batch_for_eagle:
                # ngram and other speculative decoding methods use the sampled
                # tokens on the CPU, so they are run after bookkeeping.
                propose_draft_token_ids(valid_sampled_token_ids)

        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()
```
```
@@ -92,8 +92,10 @@ class CachedRequestState:
    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        else:
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            return -1


class InputBatch:
```