[Spec Decode]clean up spec decode interface (#6947)
This pull request refactors the speculative decoding proposer interface
to align with upstream vLLM, removing the local `Proposer` interface and
renaming methods to `propose`.
This is the first step. In the future, once the architecture in vLLM is ready, we should remove the class registry and just add a few Ascend-specific methods.
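To illustrate the direction of the refactor, here is a minimal sketch of the renamed `propose` entry points and the type-based dispatch they enable. The class names mirror the diff below, but the signatures and bodies are simplified stand-ins, not the real vllm_ascend implementation:

```python
# Minimal sketch only. Class names mirror the diff below; signatures and
# bodies are simplified stand-ins, not the real vllm_ascend implementation.


class AscendNgramProposer:
    def propose(self, sampled_token_ids: list[list[int]]) -> list[list[int]]:
        # Draft by matching recent n-grams against each request's history.
        return [[] for _ in sampled_token_ids]


class AscendSuffixDecodingProposer:
    def propose(self, sampled_token_ids: list[list[int]]) -> list[list[int]]:
        # Draft by replaying suffixes observed earlier in the request.
        return [[] for _ in sampled_token_ids]


class AscendMedusaProposer:
    def propose(self, sampled_token_ids, sampling_metadata,
                spec_decode_metadata, sample_hidden_states) -> list[list[int]]:
        # Medusa heads additionally consume the last forward pass's hidden states.
        return [[] for _ in sampled_token_ids]


def draft(drafter, sampled_token_ids, **extras) -> list[list[int]] | None:
    # The runner now dispatches on the drafter's concrete type rather than on
    # a method-name string via the removed local Proposer interface.
    if drafter is None:
        return None  # speculative decoding disabled
    if isinstance(drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)):
        return drafter.propose(sampled_token_ids)
    if isinstance(drafter, AscendMedusaProposer):
        return drafter.propose(
            sampled_token_ids,
            extras.get("sampling_metadata"),
            extras.get("spec_decode_metadata"),
            extras.get("sample_hidden_states"),
        )
    raise ValueError(f"Unsupported drafter: {type(drafter).__name__}")


if __name__ == "__main__":
    print(draft(AscendNgramProposer(), [[42, 7], [13]]))  # -> [[], []]
```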
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -74,8 +74,6 @@ from vllm.v1.sample.logits_processor import build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
@@ -109,9 +107,11 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
from vllm_ascend.sample.sampler import AscendSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
from vllm_ascend.utils import (
check_gdn_layer,
enable_sp,
@@ -402,9 +402,14 @@ class NPUModelRunner(GPUModelRunner):

def _set_up_drafter(self):
# Set up speculative decoding.
self.drafter: NgramProposer | EagleProposer | MtpProposer | SuffixDecodingProposer | MedusaProposer | None = (
None
)
self.drafter: (
AscendNgramProposer
| AscendEagleProposer
| AscendMtpProposer
| AscendSuffixDecodingProposer
| AscendMedusaProposer
| None
) = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
@@ -414,7 +419,7 @@ class NPUModelRunner(GPUModelRunner):
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
if self.speculative_config.method == "eagle3":
assert isinstance(self.drafter, EagleProposer)
assert isinstance(self.drafter, AscendEagleProposer)
self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state
self.rejection_sampler = RejectionSampler(self.sampler)
self.actual_seq_lengths_q = list(
@@ -946,152 +951,134 @@ class NPUModelRunner(GPUModelRunner):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: list[dict[str, Any]] | dict[str, Any],
aux_hidden_states: torch.Tensor = None,
sample_hidden_states: torch.Tensor = None,
) -> list[list[int]] | None:
if not self.drafter:
# Speculative decoding is not enabled.
draft_token_ids = None
else:
if self.speculative_config.method in ("suffix", "ngram"):
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
aux_hidden_states,
)
elif isinstance(self.drafter, MedusaProposer):
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
)
elif self.speculative_config.use_eagle():
common_attn_metadata = spec_decode_common_attn_metadata
sampled_token_ids = valid_sampled_token_ids
elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)):
draft_token_ids = self.drafter.propose(valid_sampled_token_ids)
elif isinstance(self.drafter, AscendMedusaProposer):
draft_token_ids = self.drafter.propose(
valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
)
elif self.speculative_config.use_eagle():
common_attn_metadata = spec_decode_common_attn_metadata
sampled_token_ids = valid_sampled_token_ids

if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list whenpadded-batch is disabled."
|
||||
)
|
||||
assert self.drafter is not None
|
||||
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
|
||||
sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
|
||||
)
|
||||
else:
|
||||
# When using padded-batch, the sampled_token_ids should be
|
||||
# the gpu tensor of sampled tokens for each request, of shape
|
||||
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
||||
# value -1.
|
||||
assert isinstance(sampled_token_ids, torch.Tensor), (
|
||||
"sampled_token_ids should be a torch.Tensor whenpadded-batch is enabled."
|
||||
)
assert self.drafter is not None
next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
)
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)

req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.use_cp:
long_seq_metadata = self.long_seq_metadata # type: ignore
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.input_batch.num_reqs
num_prefill_reqs = self.pcp_manager.num_prefill_reqs
num_decode_reqs = self.pcp_manager.num_decode_reqs
else:
long_seq_metadata = None # type: ignore
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
)
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
: num_reqs + 1
]
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
token_indices_to_sample = None
assert self.drafter is not None
common_attn_metadata, token_indices = self.drafter.prepare_inputs(
common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
)
else:
assert self.drafter is not None
common_attn_metadata, token_indices, token_indices_to_sample = (
self.drafter.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list whenpadded-batch is disabled."
|
||||
)
assert self.drafter is not None
draft_token_ids = self.drafter._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens
)

else:
raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), (
"sampled_token_ids should be a torch.Tensor whenpadded-batch is enabled."
|
||||
)
assert self.drafter is not None
next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests,
)
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)

req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.use_cp:
long_seq_metadata = self.long_seq_metadata # type: ignore
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.input_batch.num_reqs
num_prefill_reqs = self.pcp_manager.num_prefill_reqs
num_decode_reqs = self.pcp_manager.num_decode_reqs
else:
long_seq_metadata = None # type: ignore
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = query_start_loc_pcp_full[1 : num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] = query_start_loc_pcp_full_cpu[
: num_reqs + 1
]
assert common_attn_metadata is not None
common_attn_metadata.query_start_loc[: num_reqs + 1] = query_start_loc_pcp_full[: num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
token_indices_to_sample = None
assert self.drafter is not None
common_attn_metadata, token_indices = self.drafter.prepare_inputs(
common_attn_metadata, sampled_token_ids, spec_decode_metadata.num_draft_tokens
)
else:
assert self.drafter is not None
common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
assert self.drafter is not None
draft_token_ids = self.drafter._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
else:
raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")

return draft_token_ids
@@ -1460,7 +1447,6 @@ class NPUModelRunner(GPUModelRunner):
positions,
scheduler_output.total_num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
sample_hidden_states,
)
@@ -2088,7 +2074,7 @@ class NPUModelRunner(GPUModelRunner):
if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if isinstance(self.drafter, AscendEagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = cm
else:
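As a side note, the padded and non-padded drafter-batch paths in the diff above distinguish two `sampled_token_ids` layouts. A hedged illustration (not vllm_ascend code), following the comments in the diff:

```python
import torch

# Illustration only, based on the comments in the diff above.

# With disable_padded_drafter_batch=True: a CPU-side list per request, where a
# request whose sample was discarded contributes an empty list.
sampled_cpu: list[list[int]] = [[42, 7], [], [13]]

# With disable_padded_drafter_batch=False: a tensor of shape
# (num_reqs, num_spec_tokens + 1), with rejected slots set to -1.
sampled_padded = torch.tensor([
    [42, 7, -1],
    [-1, -1, -1],
    [13, -1, -1],
])

assert sampled_padded.shape == (3, 3)
```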