[Feat] Refactor rejection sampler (#4975)

### What this PR does / why we need it?

Currently, spec decoding uses `AscendRejectionSampler`, which extends
`RejectionSampler`. `AscendRejectionSampler` overrides the `forward` method of
`RejectionSampler` solely to replace the `rejection_sample` function. This
prevents much of `RejectionSampler`'s code from being reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223
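
To illustrate the coupling, here is a simplified, hypothetical sketch of the old structure (the real class, removed by this PR, appears in the diff below):

```python
from vllm.v1.sample.rejection_sampler import RejectionSampler


# Simplified sketch for illustration only, not the actual class.
class AscendRejectionSampler(RejectionSampler):
    def forward(self, metadata, draft_probs, target_logits,
                bonus_token_ids, sampling_metadata):
        # Near-verbatim copy of RejectionSampler.forward, maintained by hand
        # only so the final call can dispatch to the NPU rejection_sample().
        # Upstream changes to forward(), like the PRs above, are not picked
        # up automatically.
        ...
```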

#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in the
model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and
`RejectionSampler.rejection_sample`; a better long-term approach may be to
make them custom ops. A sketch of the patching approach follows this list.
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
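
The patching approach boils down to rebinding module-level functions in `vllm.v1.sample.rejection_sampler`; this mirrors the `patch_rejection_sampler.py` file added by this PR (shown in the diff below):

```python
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
                                                  expand_batch_to_tokens,
                                                  rejection_sample)

# Rebinding entries in the module's namespace takes effect for any code that
# resolves these names from the module at call time, including
# RejectionSampler.forward itself (a function's globals are its defining
# module's dict), so no subclass is needed.
rs.apply_sampling_constraints = apply_sampling_constraints
rs.rejection_sample = rejection_sample
rs.expand_batch_to_tokens = expand_batch_to_tokens
```

Call sites that imported these functions by name before the patch ran would keep the originals, which is why the patch module is imported eagerly in `vllm_ascend.patch.worker` (see the diff below).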

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (tested with
https://github.com/vllm-project/vllm-ascend/pull/4893/)
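
The logprobs tests compare the spec-decode run against a plain reference run token by token; the core assertion (excerpted from the test added in this PR, where `ref_logprobs` and `spec_logprobs` are the per-token `Logprob` objects collected from the two runs) is:

```python
import math

# Logprob values must agree within tolerance; rank and decoded token exactly.
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
    assert math.isclose(ref_logprob.logprob, spec_logprob.logprob,
                        rel_tol=5e-2, abs_tol=1e-1)
    assert ref_logprob.rank == spec_logprob.rank
    assert ref_logprob.decoded_token == spec_logprob.decoded_token
```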


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-12-16 11:32:26 +08:00
committed by GitHub
parent 5f840696c1
commit 9e24bdd44c
6 changed files with 260 additions and 236 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import math
import os
import random
from typing import Any
@@ -239,3 +240,56 @@ def test_suffix_acceptance(
# Heuristic: expect at least a 60% acceptance rate at the end.
assert last_accept_rate > 0.60
@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
logprobs=1,
max_tokens=10,
ignore_eos=False)
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
ref_outputs = ref_llm.chat([prompt], sampling_params)
ref_logprobs = []
for output in ref_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=False,
) as runner:
spec_outputs = runner.model.chat([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob,
spec_logprob.logprob,
rel_tol=5e-2,
abs_tol=1e-1)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token

View File

@@ -228,7 +228,7 @@
# Future Plan:
# Remove this patch when the bug is fixed.
#
# ** File: worker/patch_qwen3_next_mtp.py**
# ** 11. File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.utils.bind_kv_cache`
# Why:
@@ -241,7 +241,7 @@
# Future Plan:
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
# ** 12. File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
@@ -257,3 +257,19 @@
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
#
# ** 13. File: worker/patch_rejection_sampler.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.rejection_sampler`
# Why:
# - some functions from `rejection_sampler` are not supported or are slow on npu.
# How:
# - add npu_top_k_top_p to the 'apply_sampling_constraints' func (sketched at the end of this note)
# - add custom triton kernels for `expand_batch_to_tokens` and `rejection_sample`
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/874
# https://github.com/vllm-project/vllm/pull/4849
# Future Plan:
# 1. make these functions class methods of RejectionSampler, create AscendRejectionSampler
# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
# 2. make these functions custom ops, then remove AscendRejectionSampler
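# Sketch of the top-k/top-p hook (illustrative only; assumes torch_npu
# exposes the fused kernel npu_top_k_top_p(logits, p, k) used elsewhere in
# vllm_ascend):
#   logits = torch_npu.npu_top_k_top_p(logits, p, k)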
#

View File

@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa

View File

@@ -0,0 +1,11 @@
import vllm.v1.sample.rejection_sampler as rs
from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
expand_batch_to_tokens,
rejection_sample)
# TODO: delete this patch after apply_sampling_constraints and rejection_sample
# are extracted as class methods of RejectionSampler
rs.apply_sampling_constraints = apply_sampling_constraints
rs.rejection_sample = rejection_sample
rs.expand_batch_to_tokens = expand_batch_to_tokens

View File

@@ -2,15 +2,11 @@
from typing import Optional
import torch
import torch.nn as nn
import torch_npu
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import generate_uniform_probs
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
@@ -21,92 +17,6 @@ GREEDY_TEMPERATURE = -1
MAX_SPEC_LEN = 32
class AscendRejectionSampler(RejectionSampler, nn.Module):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_logits = apply_sampling_constraints(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
return output_token_ids
def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
@@ -844,6 +754,3 @@ def sample_recovered_tokens_kernel(
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
rs.expand_batch_to_tokens = expand_batch_to_tokens

View File

@@ -69,9 +69,11 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
MambaSpec, MLAAttentionSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
ModelRunnerOutput,
LogprobsLists, LogprobsTensors, ModelRunnerOutput,
SamplerOutput,
make_empty_encoder_model_runner_output)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -112,7 +114,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.sample.sampler import AscendSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -424,7 +425,7 @@ class NPUModelRunner(GPUModelRunner):
)
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
self.rejection_sampler = AscendRejectionSampler(self.sampler)
self.rejection_sampler = RejectionSampler(self.sampler)
self.actual_seq_lengths_q = list(
range(self.decode_token_per_req, self.max_num_tokens + 1,
self.decode_token_per_req))
@@ -1353,7 +1354,7 @@ class NPUModelRunner(GPUModelRunner):
draft_token_ids = draft_token_ids[target_logits_indices + 1]
if self.pcp_size > 1:
logits_indices = logits_indices_pcp
metadata = SpecDecodeMetadata(
return SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
@@ -1362,7 +1363,6 @@ class NPUModelRunner(GPUModelRunner):
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
def propose_draft_token_ids(
self,
@@ -1719,145 +1719,13 @@ class NPUModelRunner(GPUModelRunner):
grammar_output, logits)
with ProfileExecuteDuration().capture_async("Sample"):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
if lmhead_tp_enable() and logits is not None:
logits = logits[:self.input_batch.num_reqs]
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
if lmhead_tp_enable() and logits is not None:
logits = logits[:len(spec_decode_metadata.logits_indices)]
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[
spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[
spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
generator = self.input_batch.generators.get(int(i))
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = \
self.input_batch.req_id_to_index.copy()
# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:scheduler_output.total_num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens. It's a tensor.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens. It's a numpy array
valid_sampled_token_ids, _ = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
)
invalid_req_indices_set = set(invalid_req_indices)
if self.num_spec_tokens <= 0:
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] * 1 if \
req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.model_config.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.model_config.max_model_len}")
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx,
start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(sampled_token_ids):
assert self.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
sampled_token_ids,
sampling_metadata,
self.input_batch.sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
@@ -1867,6 +1735,22 @@ class NPUModelRunner(GPUModelRunner):
aux_hidden_states,
)
(
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
) = self._bookkeeping_sync(
scheduler_output,
sampler_output,
logits,
hidden_states,
scheduler_output.total_num_scheduled_tokens,
spec_decode_metadata,
)
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
use_padded_batch_for_eagle = self.speculative_config and \
@@ -1920,13 +1804,164 @@ class NPUModelRunner(GPUModelRunner):
self.debugger.step()
return AsyncGPUModelRunnerOutput(
model_runner_output=model_runner_output,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
vocab_size=self.input_batch.vocab_size,
)
# Override _sample to handle lmhead_tp_enable and need_accepted_tokens
def _sample(self, logits, spec_decode_metadata):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
if lmhead_tp_enable() and logits is not None:
logits = logits[:self.input_batch.num_reqs]
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
if lmhead_tp_enable() and logits is not None:
logits = logits[:len(spec_decode_metadata.logits_indices)]
sampler_output = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
logits,
sampling_metadata,
)
if self.need_accepted_tokens: # TODO remove this if
self._update_states_after_model_execute(
sampler_output.sampled_token_ids)
return sampler_output
# TODO: remove this func after eagle_proposer is refactored and
# _bookkeeping_sync is moved after propose_draft_token_ids
def _bookkeeping_sync(
self,
scheduler_output: "SchedulerOutput",
sampler_output: SamplerOutput,
logits: torch.Tensor | None,
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
spec_decode_metadata: SpecDecodeMetadata | None,
) -> tuple[
LogprobsLists | None,
list[list[int]],
dict[str, LogprobsTensors | None],
list[str],
dict[str, int],
list[int],
]:
# TODO: implement PR 28597 from vllm
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
gen = self.input_batch.generators.get(int(i))
if gen is not None:
gen.set_offset(gen.get_offset() - 4)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
cu_num_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None,
)
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices)
if self.num_spec_tokens <= 0:
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [
-1
] if req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + num_sampled_ids
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling
and logprobs_tensors is not None else None)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
return (
logprobs_lists,
valid_sampled_token_ids,
prompt_logprobs_dict,
req_ids_output_copy,
req_id_to_index_output_copy,
invalid_req_indices,
)
def _build_dummy_attn_metadata(
self,
with_prefill: bool,