From 9e24bdd44c2a71f61002e60a4dad8095f3af652d Mon Sep 17 00:00:00 2001
From: realliujiaxu
Date: Tue, 16 Dec 2025 11:32:26 +0800
Subject: [PATCH] [Feat] Refactor rejection sampler (#4975)

### What this PR does / why we need it?
Currently, we are using `AscendRejectionSampler`, which extends `RejectionSampler`, in spec decoding.
`AscendRejectionSampler` overrides `forward` of `RejectionSampler`, aiming only to replace the
`rejection_sample` func. As a result, a lot of code in `RejectionSampler` cannot be reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223

#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in the model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and `RejectionSampler.rejection_sample`; a better way may be to make them custom ops.
- Modify `NPUModelRunner` following https://github.com/vllm-project/vllm/pull/26060

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (test with https://github.com/vllm-project/vllm-ascend/pull/4893/)

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: realliujiaxu
---
 .../spec_decode_v1/test_v1_spec_decode.py   |  54 +++
 vllm_ascend/patch/__init__.py               |  20 +-
 vllm_ascend/patch/worker/__init__.py        |   1 +
 .../patch/worker/patch_rejection_sampler.py |  11 +
 vllm_ascend/sample/rejection_sampler.py     |  95 +-----
 vllm_ascend/worker/model_runner_v1.py       | 315 ++++++++++--------
 6 files changed, 260 insertions(+), 236 deletions(-)
 create mode 100644 vllm_ascend/patch/worker/patch_rejection_sampler.py

diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index 3d7c5453..f207c64d 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+import math
 import os
 import random
 from typing import Any
@@ -239,3 +240,56 @@ def test_suffix_acceptance(
     # Heuristic: expect at least 80% acceptance rate at the end.
assert last_accept_rate > 0.60 + + +@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"]) +def test_eagle_logprobs( + model_name: str, + use_eagle3: bool, +): + prompt = {"role": "user", "content": "Hello world " * 10} + sampling_params = SamplingParams(temperature=0, + logprobs=1, + max_tokens=10, + ignore_eos=False) + + ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False) + ref_outputs = ref_llm.chat([prompt], sampling_params) + ref_logprobs = [] + for output in ref_outputs[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + + spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() + with VllmRunner( + model_name, + max_num_seqs=1, + max_num_batched_tokens=2048, + gpu_memory_utilization=0.6, + speculative_config={ + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, + "num_speculative_tokens": 2, + "max_model_len": 128, + }, + max_model_len=128, + enforce_eager=False, + ) as runner: + spec_outputs = runner.model.chat([prompt], sampling_params) + + # Collect logprobs outputs from spec decode LLM. + spec_logprobs = [] + for output in spec_outputs[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + spec_logprobs.append(logprobs[token_id]) + + for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): + assert math.isclose(ref_logprob.logprob, + spec_logprob.logprob, + rel_tol=5e-2, + abs_tol=1e-1) + assert ref_logprob.rank == spec_logprob.rank + assert ref_logprob.decoded_token == spec_logprob.decoded_token diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 092e2ce5..33da508f 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -228,7 +228,7 @@ # Future Plan: # Remove this patch when the bug is fixed. # -# ** File: worker/patch_qwen3_next_mtp.py** +# ** 11. File: worker/patch_qwen3_next_mtp.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.v1.worker.utils.bind_kv_cache` # Why: @@ -241,7 +241,7 @@ # Future Plan: # Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu. # -# ** File: worker/patch_module.py** +# ** 12. File: worker/patch_module.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort` # Why: @@ -257,3 +257,19 @@ # Remove this patch when bool is supported in 'torch.argsort' func of npu. # Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable. # +# ** 13. File: worker/patch_rejection_sampler.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.sample.rejection_sampler` +# Why: +# - some functions from `rejection_sampler` are not supported or slow on npu. +# How: +# - add npu_top_k_top_p to 'apply_sampling_constraints' func +# - add custom triton kernel to `expand_batch_to_tokens` and `rejection_sample` +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/874 +# https://github.com/vllm-project/vllm/pull/4849 +# Future Plan: +# 1. make these functions as class func of RejectionSampler, create AscendRejectionSampler +# to override them, then delete the patch file `worker/patch_rejection_sampler.py`. +# 2. 
make these functions custom ops, then remove AscendRejectionSampler
+#
\ No newline at end of file
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 45e37a5d..07a77f7d 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_omni  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa
 import vllm_ascend.patch.worker.patch_rope  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
+import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa
diff --git a/vllm_ascend/patch/worker/patch_rejection_sampler.py b/vllm_ascend/patch/worker/patch_rejection_sampler.py
new file mode 100644
index 00000000..f94fee60
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_rejection_sampler.py
@@ -0,0 +1,11 @@
+import vllm.v1.sample.rejection_sampler as rs
+
+from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
+                                                  expand_batch_to_tokens,
+                                                  rejection_sample)
+
+# TODO: delete this patch after apply_sampling_constraints and rejection_sample
+# are extracted as class methods of RejectionSampler
+rs.apply_sampling_constraints = apply_sampling_constraints
+rs.rejection_sample = rejection_sample
+rs.expand_batch_to_tokens = expand_batch_to_tokens
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index c1ef10db..4ab03a77 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -2,15 +2,11 @@
 from typing import Optional
 
 import torch
-import torch.nn as nn
 import torch_npu
-import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.sample.rejection_sampler import (RejectionSampler,
-                                              generate_uniform_probs)
-from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.sample.rejection_sampler import generate_uniform_probs
 
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
@@ -21,92 +17,6 @@
 GREEDY_TEMPERATURE = -1
 MAX_SPEC_LEN = 32
 
-class AscendRejectionSampler(RejectionSampler, nn.Module):
-    """
-    The implementation strictly follows the algorithm described in
-    https://arxiv.org/abs/2211.17192.
-    However, we want to clarify the terminology used in the implementation:
-    accepted tokens: tokens that are accepted based on the relationship
-        between the "raw" draft and target probabilities.
-    recovered tokens: tokens that are sampled based on the adjusted probability
-        distribution, which is derived from both the draft and target
-        probabilities.
-    bonus tokens:
-        If all proposed tokens are accepted, the bonus token is added to the
-        end of the sequence. The bonus token is only sampled from the target
-        probabilities. We pass in the bonus tokens instead of sampling them
-        in the rejection sampler to allow for more flexibility in the
-        sampling process. For example, we can use top_p, top_k sampling for
-        bonus tokens, while spec decode does not support these sampling
-        strategies.
-    output tokens:
-        Tokens are finally generated with the rejection sampler.
- output tokens = accepted tokens + recovered tokens + bonus tokens - """ - - def forward( - self, - metadata: SpecDecodeMetadata, - # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], - # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - ''' - Args: - metadata: - Metadata for spec decoding. - draft_probs (Optional[torch.Tensor]): - Probability distribution for the draft tokens. Shape is - [num_tokens, vocab_size]. Can be None if probabilities are - not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): - Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. - sampling_metadata (SamplingMetadata): - Additional metadata needed for sampling, such as temperature, - top-k/top-p parameters, or other relevant information. - Returns: - output_token_ids (torch.Tensor): - A tensor containing the final output token IDs. - ''' - assert metadata.max_spec_len <= MAX_SPEC_LEN - # [num_tokens, vocab_size] - # NOTE(woosuk): `target_logits` can be updated in place inside the - # `compute_probs` function. 
- target_logits = apply_sampling_constraints( - target_logits, - metadata.cu_num_draft_tokens, - sampling_metadata, - ) - target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) - - output_token_ids = rejection_sample( - metadata.draft_token_ids, - metadata.num_draft_tokens, - metadata.max_spec_len, - metadata.cu_num_draft_tokens, - draft_probs, - target_probs, - bonus_token_ids, - sampling_metadata, - ) - return output_token_ids - - def apply_sampling_constraints( logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] @@ -844,6 +754,3 @@ def sample_recovered_tokens_kernel( tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) - - -rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 77c53cf0..151ef980 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -69,9 +69,11 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ModelRunnerOutput, + LogprobsLists, LogprobsTensors, ModelRunnerOutput, + SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -112,7 +114,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs -from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer @@ -424,7 +425,7 @@ class NPUModelRunner(GPUModelRunner): ) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() - self.rejection_sampler = AscendRejectionSampler(self.sampler) + self.rejection_sampler = RejectionSampler(self.sampler) self.actual_seq_lengths_q = list( range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req)) @@ -1353,7 +1354,7 @@ class NPUModelRunner(GPUModelRunner): draft_token_ids = draft_token_ids[target_logits_indices + 1] if self.pcp_size > 1: logits_indices = logits_indices_pcp - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, @@ -1362,7 +1363,6 @@ class NPUModelRunner(GPUModelRunner): bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def propose_draft_token_ids( self, @@ -1719,145 +1719,13 @@ class NPUModelRunner(GPUModelRunner): grammar_output, logits) with ProfileExecuteDuration().capture_async("Sample"): - # Sample the next token and get logprobs if needed. 
- sampling_metadata = self.input_batch.sampling_metadata - if spec_decode_metadata is None: - if lmhead_tp_enable() and logits is not None: - logits = logits[:self.input_batch.num_reqs] - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[ - spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[ - spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - if self.need_accepted_tokens: - self._update_states_after_model_execute(output_token_ids) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] - for i in discard_sampled_tokens_req_indices: - generator = self.input_batch.generators.get(int(i)) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # Copy some objects so they don't get modified after returning. - # This is important when using async scheduling. - req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], - scheduler_output.num_scheduled_tokens, - ) - - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - sampled_token_ids = sampler_output.sampled_token_ids - - if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. It's a tensor. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. It's a numpy array - valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - else: - valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist( - ) - invalid_req_indices_set = set(invalid_req_indices) - if self.num_spec_tokens <= 0: - assert sampled_token_ids.shape[-1] == 1 - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. 
- self.input_batch.prev_sampled_token_ids = sampled_token_ids - - - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set - self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set - } - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - for req_idx in range(num_sampled_tokens): - if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None - else: - sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, - start_idx:end_idx] = True - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) + sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( sampled_token_ids, - sampling_metadata, + self.input_batch.sampling_metadata, scheduler_output, spec_decode_metadata, positions, @@ -1867,6 +1735,22 @@ class NPUModelRunner(GPUModelRunner): aux_hidden_states, ) + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + with ProfileExecuteDuration().capture_async("Draft"): if self.speculative_config: use_padded_batch_for_eagle = self.speculative_config and \ @@ -1920,13 +1804,164 @@ class NPUModelRunner(GPUModelRunner): self.debugger.step() return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=sampler_output.sampled_token_ids, logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, vocab_size=self.input_batch.vocab_size, ) + # overwrite _sample for lmhead_tp_enable and need_accepted_tokens + def _sample(self, logits, spec_decode_metadata): + # Sample the next token and get logprobs if needed. 
+ sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[:self.input_batch.num_reqs] + return self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + if lmhead_tp_enable() and logits is not None: + logits = logits[:len(spec_decode_metadata.logits_indices)] + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, + ) + if self.need_accepted_tokens: # TODO remove this if + self._update_states_after_model_execute( + sampler_output.sampled_token_ids) + return sampler_output + + # TODO: remove this func after eagle_proposer is refactored and + # _bookkeeping_sync is moved after propose_draft_token_ids + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> tuple[ + LogprobsLists | None, + list[list[int]], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], + ]: + # TODO: implement PR 28597 from vllm + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + logprobs_tensors = sampler_output.logprobs_tensors + invalid_req_indices = [] + cu_num_tokens: list[int] | None = None + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() + else: + # Includes spec decode tokens. + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + discard_sampled_tokens_req_indices, + return_cu_num_tokens=logprobs_tensors is not None, + ) + else: + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + invalid_req_indices_set = set(invalid_req_indices) + + if self.num_spec_tokens <= 0: + assert sampled_token_ids.shape[-1] == 1 + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = sampled_token_ids + + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. 
+ req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [ + -1 + ] if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + + num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + num_sampled_ids + assert end_idx <= self.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens) + if not self.use_async_scheduling + and logprobs_tensors is not None else None) + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + + return ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + def _build_dummy_attn_metadata( self, with_prefill: bool,