[Feat] Refactor rejection sampler (#4975)
### What this PR does / why we need it?
Currently, we use `AscendRejectionSampler`, which extends
`RejectionSampler`, in spec decoding. `AscendRejectionSampler` overrides
the `forward` of `RejectionSampler` solely to swap in its own
`rejection_sample` function. This prevents much of the code in
`RejectionSampler` from being reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223
#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in
the model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and
`RejectionSampler.rejection_sample`; a better long-term approach may be
to implement them as custom ops (see the sketch after this list).
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
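For illustration, the patching approach in the second bullet could look like the sketch below. This is a minimal sketch, not the code in this PR: the replacement bodies are hypothetical stand-ins, and only the two patched names come from the proposal above.

```python
# Minimal sketch: rebind the two hot functions so the rest of the upstream
# RejectionSampler (forward, logprobs handling, parse_output, ...) is
# reused unchanged. The replacement bodies below are hypothetical.
from vllm.v1.sample.rejection_sampler import RejectionSampler


def _ascend_expand_batch_to_tokens(*args, **kwargs):
    # An NPU-friendly reimplementation would go here.
    raise NotImplementedError


def _ascend_rejection_sample(*args, **kwargs):
    # An NPU-friendly reimplementation would go here.
    raise NotImplementedError


RejectionSampler.expand_batch_to_tokens = staticmethod(
    _ascend_expand_batch_to_tokens)
RejectionSampler.rejection_sample = staticmethod(_ascend_rejection_sample)
```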
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding (example sketch below)
- [x] test logprobs for spec decoding + async scheduling (tested with
https://github.com/vllm-project/vllm-ascend/pull/4893/)
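A hedged example of the style of check behind the second item; the model name and speculative settings are illustrative, not the exact test configuration:

```python
from vllm import LLM, SamplingParams

# Spot-check that per-token logprobs come back well-formed under spec decoding.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",
          speculative_config={"method": "ngram",
                              "num_speculative_tokens": 3,
                              "prompt_lookup_max": 3})
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8, logprobs=5))
for step_logprobs in outputs[0].outputs[0].logprobs:
    assert len(step_logprobs) >= 5  # top-5, plus the sampled token if outside
```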
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
```diff
@@ -69,9 +69,11 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         MambaSpec, MLAAttentionSpec,
                                         UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             ModelRunnerOutput,
+                             LogprobsLists, LogprobsTensors, ModelRunnerOutput,
+                             SamplerOutput,
                              make_empty_encoder_model_runner_output)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
```
```diff
@@ -112,7 +114,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
-from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
```
```diff
@@ -424,7 +425,7 @@ class NPUModelRunner(GPUModelRunner):
             )
             if get_pp_group().is_last_rank:
                 self.drafter = self._get_drafter()
-                self.rejection_sampler = AscendRejectionSampler(self.sampler)
+                self.rejection_sampler = RejectionSampler(self.sampler)
             self.actual_seq_lengths_q = list(
                 range(self.decode_token_per_req, self.max_num_tokens + 1,
                       self.decode_token_per_req))
```
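For readers new to the upstream class: conceptually, the rejection sampler verifies each draft token against the target model and emits a recovery or bonus token after the accepted prefix. A minimal sketch of the greedy special case (not the vLLM implementation, which is vectorized over the batch):

```python
import torch


def greedy_verify(draft_token_ids: torch.Tensor,
                  target_argmax: torch.Tensor) -> list[int]:
    """Accept draft tokens while they match the target model's argmax;
    the first disagreeing target token ends the emitted sequence."""
    accepted: list[int] = []
    for draft, target in zip(draft_token_ids.tolist(),
                             target_argmax.tolist()):
        accepted.append(target)
        if draft != target:
            break
    return accepted


# drafts [5, 7, 9] vs. target argmax [5, 8, 2] -> emits [5, 8]
print(greedy_verify(torch.tensor([5, 7, 9]), torch.tensor([5, 8, 2])))
```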
```diff
@@ -1353,7 +1354,7 @@ class NPUModelRunner(GPUModelRunner):
             draft_token_ids = draft_token_ids[target_logits_indices + 1]
         if self.pcp_size > 1:
             logits_indices = logits_indices_pcp
-        metadata = SpecDecodeMetadata(
+        return SpecDecodeMetadata(
             draft_token_ids=draft_token_ids,
             num_draft_tokens=num_draft_tokens.tolist(),
             cu_num_draft_tokens=cu_num_draft_tokens,
```
```diff
@@ -1362,7 +1363,6 @@ class NPUModelRunner(GPUModelRunner):
             bonus_logits_indices=bonus_logits_indices,
             logits_indices=logits_indices,
         )
-        return metadata
 
     def propose_draft_token_ids(
         self,
```
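As context for the fields above, a worked example of the index layout (values illustrative; it assumes the usual vLLM convention of one bonus row following each request's target rows):

```python
# Two requests with 2 and 1 draft tokens: 5 logits rows in total, laid out
# as [t0, t1, b0 | t2, b1] (t = target/verification row, b = bonus row).
num_draft_tokens = [2, 1]
cu_num_draft_tokens = [2, 3]       # running total of draft tokens
target_logits_indices = [0, 1, 3]  # rows that verify the draft tokens
bonus_logits_indices = [2, 4]      # rows that sample the bonus token
logits_indices = [0, 1, 2, 3, 4]   # all rows handed to the sampler
```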
```diff
@@ -1719,145 +1719,13 @@ class NPUModelRunner(GPUModelRunner):
                 grammar_output, logits)
 
         with ProfileExecuteDuration().capture_async("Sample"):
-            # Sample the next token and get logprobs if needed.
-            sampling_metadata = self.input_batch.sampling_metadata
-            if spec_decode_metadata is None:
-                if lmhead_tp_enable() and logits is not None:
-                    logits = logits[:self.input_batch.num_reqs]
-                sampler_output = self.sampler(
-                    logits=logits,
-                    sampling_metadata=sampling_metadata,
-                )
-            else:
-                if lmhead_tp_enable() and logits is not None:
-                    logits = logits[:len(spec_decode_metadata.logits_indices)]
-                # When indexing with a tensor (bonus_logits_indices), PyTorch
-                # creates a new tensor with separate storage from the original
-                # logits tensor. This means any in-place operations on bonus_logits
-                # won't affect the original logits tensor.
-                assert logits is not None
-                bonus_logits = logits[
-                    spec_decode_metadata.bonus_logits_indices]
-                sampler_output = self.sampler(
-                    logits=bonus_logits,
-                    sampling_metadata=sampling_metadata,
-                )
-                bonus_token_ids = sampler_output.sampled_token_ids
-
-                # Just like `bonus_logits`, `target_logits` is a new tensor with
-                # separate storage from the original `logits` tensor. Therefore,
-                # it is safe to update `target_logits` in place.
-                target_logits = logits[
-                    spec_decode_metadata.target_logits_indices]
-                output_token_ids = self.rejection_sampler(
-                    spec_decode_metadata,
-                    None,  # draft_probs
-                    target_logits,
-                    bonus_token_ids,
-                    sampling_metadata,
-                )
-                sampler_output.sampled_token_ids = output_token_ids
-                if self.need_accepted_tokens:
-                    self._update_states_after_model_execute(output_token_ids)
-            discard_sampled_tokens_req_indices = \
-                self.discard_request_indices.np[:self.num_discarded_requests]
-            for i in discard_sampled_tokens_req_indices:
-                generator = self.input_batch.generators.get(int(i))
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-            # Copy some objects so they don't get modified after returning.
-            # This is important when using async scheduling.
-            req_ids_output_copy = self.input_batch.req_ids.copy()
-            req_id_to_index_output_copy = \
-                self.input_batch.req_id_to_index.copy()
-
-            # NOTE: NPU -> CPU Sync happens here.
-            # Move as many CPU operations as possible before this sync point.
-            logprobs_tensors = sampler_output.logprobs_tensors
-            logprobs_lists = logprobs_tensors.tolists() \
-                if logprobs_tensors is not None else None
-
-            # Compute prompt logprobs if needed.
-            prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-                hidden_states[:scheduler_output.total_num_scheduled_tokens],
-                scheduler_output.num_scheduled_tokens,
-            )
-
-            num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
-            sampled_token_ids = sampler_output.sampled_token_ids
-
-            if not self.use_async_scheduling:
-                # Get the valid generated tokens.
-                max_gen_len = sampled_token_ids.shape[-1]
-                if max_gen_len == 1:
-                    # No spec decode tokens. It's a tensor.
-                    valid_sampled_token_ids = sampled_token_ids.tolist()
-                else:
-                    # Includes spec decode tokens. It's a numpy array
-                    valid_sampled_token_ids, _ = self.rejection_sampler.parse_output(
-                        sampled_token_ids,
-                        self.input_batch.vocab_size,
-                    )
-                # Mask out the sampled tokens that should not be sampled.
-                for i in discard_sampled_tokens_req_indices:
-                    valid_sampled_token_ids[int(i)].clear()
-            else:
-                valid_sampled_token_ids = []
-                invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
-                )
-                invalid_req_indices_set = set(invalid_req_indices)
-                if self.num_spec_tokens <= 0:
-                    assert sampled_token_ids.shape[-1] == 1
-                # Cache the sampled tokens on the NPU and avoid CPU sync.
-                # These will be copied into input_ids in the next step
-                # when preparing inputs.
-                self.input_batch.prev_sampled_token_ids = sampled_token_ids
-
-                self.input_batch.prev_sampled_token_ids_invalid_indices = \
-                    invalid_req_indices_set
-                self.input_batch.prev_req_id_to_index = {
-                    req_id: i
-                    for i, req_id in enumerate(self.input_batch.req_ids)
-                    if i not in invalid_req_indices_set
-                }
-            # Cache the sampled tokens in the model runner, so that the scheduler
-            # doesn't need to send them back.
-            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
-            # the sampled tokens back, because there's no direct communication
-            # between the first-stage worker and the last-stage worker.
-            for req_idx in range(num_sampled_tokens):
-                if self.use_async_scheduling:
-                    sampled_ids = [-1] * 1 if \
-                        req_idx not in invalid_req_indices_set else None
-                else:
-                    sampled_ids = valid_sampled_token_ids[req_idx]
-                if not sampled_ids:
-                    continue
-
-                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-                end_idx = start_idx + len(sampled_ids)
-                assert end_idx <= self.model_config.max_model_len, (
-                    "Sampled token IDs exceed the max model length. "
-                    f"Total number of tokens: {end_idx} > max_model_len: "
-                    f"{self.model_config.max_model_len}")
-
-                self.input_batch.token_ids_cpu[req_idx,
-                                               start_idx:end_idx] = sampled_ids
-                self.input_batch.is_token_ids[req_idx,
-                                              start_idx:end_idx] = True
-                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-                self.input_batch.num_tokens[req_idx] = end_idx
-                req_id = self.input_batch.req_ids[req_idx]
-                req_state = self.requests[req_id]
-                req_state.output_token_ids.extend(sampled_ids)
+            sampler_output = self._sample(logits, spec_decode_metadata)
 
         def propose_draft_token_ids(sampled_token_ids):
             assert self.spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(
                 sampled_token_ids,
-                sampling_metadata,
+                self.input_batch.sampling_metadata,
                 scheduler_output,
                 spec_decode_metadata,
                 positions,
```
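The removed block also documented the device-to-host sync point: calling `tolist()`/`tolists()` on the sampled token IDs blocks until the NPU finishes, so CPU-only bookkeeping should be scheduled before it. The `_bookkeeping_sync` helper later in this diff preserves that ordering. An illustrative pattern, not project code:

```python
import torch

# Stand-in for the device tensor of sampled token IDs; on an NPU/GPU tensor,
# .tolist() forces a blocking device -> host copy and acts as the sync point.
sampled_token_ids = torch.randint(0, 32000, (4, 1))

# ... CPU-only bookkeeping here can overlap with device execution ...

valid_sampled_token_ids = sampled_token_ids.tolist()  # sync happens here
```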
```diff
@@ -1867,6 +1735,22 @@ class NPUModelRunner(GPUModelRunner):
                 aux_hidden_states,
             )
 
+        (
+            logprobs_lists,
+            valid_sampled_token_ids,
+            prompt_logprobs_dict,
+            req_ids_output_copy,
+            req_id_to_index_output_copy,
+            invalid_req_indices,
+        ) = self._bookkeeping_sync(
+            scheduler_output,
+            sampler_output,
+            logits,
+            hidden_states,
+            scheduler_output.total_num_scheduled_tokens,
+            spec_decode_metadata,
+        )
+
         with ProfileExecuteDuration().capture_async("Draft"):
             if self.speculative_config:
                 use_padded_batch_for_eagle = self.speculative_config and \
```
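In `_bookkeeping_sync` below, `RejectionSampler.parse_output` turns the padded token-ID matrix into per-request lists. Conceptually it behaves like this sketch; the real helper also returns cumulative token counts and clears discarded rows, and the `-1` placeholder ID is an assumption here:

```python
import torch


def parse_output_sketch(sampled_token_ids: torch.Tensor,
                        vocab_size: int) -> list[list[int]]:
    # Rows are padded with a placeholder ID at rejected positions; keep only
    # entries that are real, in-vocabulary tokens.
    valid_mask = (sampled_token_ids != -1) & (sampled_token_ids < vocab_size)
    return [row[mask].tolist()
            for row, mask in zip(sampled_token_ids, valid_mask)]


# [[5, 8, -1], [3, -1, -1]] -> [[5, 8], [3]]
print(parse_output_sketch(torch.tensor([[5, 8, -1], [3, -1, -1]]), 32000))
```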
```diff
@@ -1920,13 +1804,164 @@ class NPUModelRunner(GPUModelRunner):
                 self.debugger.step()
             return AsyncGPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
-                sampled_token_ids=sampled_token_ids,
+                sampled_token_ids=sampler_output.sampled_token_ids,
+                logprobs_tensors=sampler_output.logprobs_tensors,
                 invalid_req_indices=invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
                 vocab_size=self.input_batch.vocab_size,
             )
 
+    # overwrite _sample for lmhead_tp_enable and need_accepted_tokens
+    def _sample(self, logits, spec_decode_metadata):
+        # Sample the next token and get logprobs if needed.
+        sampling_metadata = self.input_batch.sampling_metadata
+        if spec_decode_metadata is None:
+            if lmhead_tp_enable() and logits is not None:
+                logits = logits[:self.input_batch.num_reqs]
+            return self.sampler(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+
+        if lmhead_tp_enable() and logits is not None:
+            logits = logits[:len(spec_decode_metadata.logits_indices)]
+        sampler_output = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            logits,
+            sampling_metadata,
+        )
+        if self.need_accepted_tokens:  # TODO remove this if
+            self._update_states_after_model_execute(
+                sampler_output.sampled_token_ids)
+        return sampler_output
+
+    # TODO: remove this func after eagle_proposer is refactored and
+    # _bookkeeping_sync is moved after propose_draft_token_ids
+    def _bookkeeping_sync(
+        self,
+        scheduler_output: "SchedulerOutput",
+        sampler_output: SamplerOutput,
+        logits: torch.Tensor | None,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: int,
+        spec_decode_metadata: SpecDecodeMetadata | None,
+    ) -> tuple[
+        LogprobsLists | None,
+        list[list[int]],
+        dict[str, LogprobsTensors | None],
+        list[str],
+        dict[str, int],
+        list[int],
+    ]:
+        # TODO: implement PR 28597 from vllm
+        discard_sampled_tokens_req_indices = \
+            self.discard_request_indices.np[:self.num_discarded_requests]
+        for i in discard_sampled_tokens_req_indices:
+            gen = self.input_batch.generators.get(int(i))
+            if gen is not None:
+                gen.set_offset(gen.get_offset() - 4)
+
+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
+
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
+        sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
+        invalid_req_indices = []
+        cu_num_tokens: list[int] | None = None
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = self._to_list(sampled_token_ids)
+                # Mask out the sampled tokens that should not be sampled.
+                for i in discard_sampled_tokens_req_indices:
+                    valid_sampled_token_ids[int(i)].clear()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                    discard_sampled_tokens_req_indices,
+                    return_cu_num_tokens=logprobs_tensors is not None,
+                )
+        else:
+            valid_sampled_token_ids = []
+            invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
+            invalid_req_indices_set = set(invalid_req_indices)
+
+            if self.num_spec_tokens <= 0:
+                assert sampled_token_ids.shape[-1] == 1
+            # Cache the sampled tokens on the NPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+            self.input_batch.prev_sampled_token_ids = sampled_token_ids
+
+            self.input_batch.prev_req_id_to_index = {
+                req_id: i
+                for i, req_id in enumerate(self.input_batch.req_ids)
+                if i not in invalid_req_indices_set
+            }
+
+        # Cache the sampled tokens in the model runner, so that the scheduler
+        # doesn't need to send them back.
+        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+        # the sampled tokens back, because there's no direct communication
+        # between the first-stage worker and the last-stage worker.
+        req_ids = self.input_batch.req_ids
+        for req_idx in range(num_sampled_tokens):
+            if self.use_async_scheduling:
+                sampled_ids = [
+                    -1
+                ] if req_idx not in invalid_req_indices_set else None
+            else:
+                sampled_ids = valid_sampled_token_ids[req_idx]
+
+            num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
+
+            if not sampled_ids:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            end_idx = start_idx + num_sampled_ids
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+
+            req_id = req_ids[req_idx]
+            req_state = self.requests[req_id]
+            req_state.output_token_ids.extend(sampled_ids)
+
+        logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens)
+                          if not self.use_async_scheduling
+                          and logprobs_tensors is not None else None)
+
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output.num_scheduled_tokens,
+        )
+
+        return (
+            logprobs_lists,
+            valid_sampled_token_ids,
+            prompt_logprobs_dict,
+            req_ids_output_copy,
+            req_id_to_index_output_copy,
+            invalid_req_indices,
+        )
+
     def _build_dummy_attn_metadata(
         self,
         with_prefill: bool,
```