[Feat] Refactor rejection sampler (#4975)
### What this PR does / why we need it?
Currently, spec decoding uses `AscendRejectionSampler`, which extends
`RejectionSampler`. `AscendRejectionSampler` overrides `forward` of
`RejectionSampler` only to replace the `rejection_sample` function. This
prevents a lot of `RejectionSampler` code from being reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223
#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in the
model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and
`RejectionSampler.rejection_sample`; a better long-term option may be to turn
them into custom ops (see the sketch after this list).
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
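
The patching approach is a plain module-level monkey patch: the Ascend implementations are assigned over the attributes of `vllm.v1.sample.rejection_sampler` at import time, so the stock `RejectionSampler` picks them up without subclassing (its `forward` resolves these helpers as module globals). A minimal sketch of the idea; `npu_rejection_sample` below is a hypothetical stand-in, and the real patch module appears in the diff further down:

```python
import torch
import vllm.v1.sample.rejection_sampler as rs


def npu_rejection_sample(*args, **kwargs) -> torch.Tensor:
    # Hypothetical stand-in for the NPU-friendly implementation; the real
    # version lives in vllm_ascend.sample.rejection_sampler.
    raise NotImplementedError


# Rebinding the module attribute is enough: RejectionSampler.forward looks
# this helper up as a global of its own module at call time, so the NPU
# version is used without overriding forward.
rs.rejection_sample = npu_rejection_sample
```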
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (tested with
https://github.com/vllm-project/vllm-ascend/pull/4893/)
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import math
import os
import random
from typing import Any
@@ -239,3 +240,56 @@ def test_suffix_acceptance(

    # Heuristic: expect at least 80% acceptance rate at the end.
    assert last_accept_rate > 0.60


@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
def test_eagle_logprobs(
    model_name: str,
    use_eagle3: bool,
):
    prompt = {"role": "user", "content": "Hello world " * 10}
    sampling_params = SamplingParams(temperature=0,
                                     logprobs=1,
                                     max_tokens=10,
                                     ignore_eos=False)

    ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
    ref_outputs = ref_llm.chat([prompt], sampling_params)
    ref_logprobs = []
    for output in ref_outputs[0].outputs:
        for logprobs in output.logprobs:
            for token_id in logprobs:
                ref_logprobs.append(logprobs[token_id])
    del ref_llm

    spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
    with VllmRunner(
            model_name,
            max_num_seqs=1,
            max_num_batched_tokens=2048,
            gpu_memory_utilization=0.6,
            speculative_config={
                "method": "eagle3" if use_eagle3 else "eagle",
                "model": spec_model_name,
                "num_speculative_tokens": 2,
                "max_model_len": 128,
            },
            max_model_len=128,
            enforce_eager=False,
    ) as runner:
        spec_outputs = runner.model.chat([prompt], sampling_params)

    # Collect logprobs outputs from spec decode LLM.
    spec_logprobs = []
    for output in spec_outputs[0].outputs:
        for logprobs in output.logprobs:
            for token_id in logprobs:
                spec_logprobs.append(logprobs[token_id])

    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
        assert math.isclose(ref_logprob.logprob,
                            spec_logprob.logprob,
                            rel_tol=5e-2,
                            abs_tol=1e-1)
        assert ref_logprob.rank == spec_logprob.rank
        assert ref_logprob.decoded_token == spec_logprob.decoded_token
@@ -228,7 +228,7 @@
#    Future Plan:
#       Remove this patch when the bug is fixed.
#
# ** File: worker/patch_qwen3_next_mtp.py**
# ** 11. File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.v1.worker.utils.bind_kv_cache`
#    Why:
@@ -241,7 +241,7 @@
#    Future Plan:
#       Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
# ** 12. File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
#    Why:
@@ -257,3 +257,19 @@
#       Remove this patch when bool is supported in 'torch.argsort' func of npu.
#       Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
#
# ** 13. File: worker/patch_rejection_sampler.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#   1. `vllm.v1.sample.rejection_sampler`
#    Why:
#       - some functions from `rejection_sampler` are not supported or slow on npu.
#    How:
#       - add npu_top_k_top_p to 'apply_sampling_constraints' func
#       - add custom triton kernel to `expand_batch_to_tokens` and `rejection_sample`
#    Related PR (if no, explain why):
#       https://github.com/vllm-project/vllm/pull/874
#       https://github.com/vllm-project/vllm/pull/4849
#    Future Plan:
#       1. make these functions class funcs of RejectionSampler, create AscendRejectionSampler
#          to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
#       2. make these functions custom ops, then remove AscendRejectionSampler
#
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_omni  # noqa
import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa
import vllm_ascend.patch.worker.patch_rope  # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp  # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa
vllm_ascend/patch/worker/patch_rejection_sampler.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
                                                  expand_batch_to_tokens,
                                                  rejection_sample)

# TODO: delete this patch after apply_sampling_constraints and rejection_sample
# are extracted as class funcs of RejectionSampler
rs.apply_sampling_constraints = apply_sampling_constraints
rs.rejection_sample = rejection_sample
rs.expand_batch_to_tokens = expand_batch_to_tokens
@@ -2,15 +2,11 @@
from typing import Optional

import torch
import torch.nn as nn
import torch_npu
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                              generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import generate_uniform_probs

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
@@ -21,92 +17,6 @@ GREEDY_TEMPERATURE = -1
MAX_SPEC_LEN = 32


class AscendRejectionSampler(RejectionSampler, nn.Module):
    """
    The implementation strictly follows the algorithm described in
    https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
        between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens are finally generated with the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits probability distribution.
                Shape is [num_tokens, vocab_size]. Here, probabilities from
                different requests are flattened into a single tensor because
                this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids_tensor (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
        target_logits = apply_sampling_constraints(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids


def apply_sampling_constraints(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
@@ -844,6 +754,3 @@ def sample_recovered_tokens_kernel(
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            orig_prob)


rs.expand_batch_to_tokens = expand_batch_to_tokens
@@ -69,9 +69,11 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        MambaSpec, MLAAttentionSpec,
                                        UniformTypeKVCacheSpecs)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             ModelRunnerOutput,
                             LogprobsLists, LogprobsTensors, ModelRunnerOutput,
                             SamplerOutput,
                             make_empty_encoder_model_runner_output)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -112,7 +114,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.sample.sampler import AscendSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -424,7 +425,7 @@ class NPUModelRunner(GPUModelRunner):
        )
        if get_pp_group().is_last_rank:
            self.drafter = self._get_drafter()
            self.rejection_sampler = AscendRejectionSampler(self.sampler)
            self.rejection_sampler = RejectionSampler(self.sampler)
        self.actual_seq_lengths_q = list(
            range(self.decode_token_per_req, self.max_num_tokens + 1,
                  self.decode_token_per_req))
@@ -1353,7 +1354,7 @@ class NPUModelRunner(GPUModelRunner):
            draft_token_ids = draft_token_ids[target_logits_indices + 1]
        if self.pcp_size > 1:
            logits_indices = logits_indices_pcp
        metadata = SpecDecodeMetadata(
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
@@ -1362,7 +1363,6 @@ class NPUModelRunner(GPUModelRunner):
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

    def propose_draft_token_ids(
        self,
@@ -1719,145 +1719,13 @@ class NPUModelRunner(GPUModelRunner):
                    grammar_output, logits)

        with ProfileExecuteDuration().capture_async("Sample"):
            # Sample the next token and get logprobs if needed.
            sampling_metadata = self.input_batch.sampling_metadata
            if spec_decode_metadata is None:
                if lmhead_tp_enable() and logits is not None:
                    logits = logits[:self.input_batch.num_reqs]
                sampler_output = self.sampler(
                    logits=logits,
                    sampling_metadata=sampling_metadata,
                )
            else:
                if lmhead_tp_enable() and logits is not None:
                    logits = logits[:len(spec_decode_metadata.logits_indices)]
                # When indexing with a tensor (bonus_logits_indices), PyTorch
                # creates a new tensor with separate storage from the original
                # logits tensor. This means any in-place operations on bonus_logits
                # won't affect the original logits tensor.
                assert logits is not None
                bonus_logits = logits[
                    spec_decode_metadata.bonus_logits_indices]
                sampler_output = self.sampler(
                    logits=bonus_logits,
                    sampling_metadata=sampling_metadata,
                )
                bonus_token_ids = sampler_output.sampled_token_ids

                # Just like `bonus_logits`, `target_logits` is a new tensor with
                # separate storage from the original `logits` tensor. Therefore,
                # it is safe to update `target_logits` in place.
                target_logits = logits[
                    spec_decode_metadata.target_logits_indices]
                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    None,  # draft_probs
                    target_logits,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids
                if self.need_accepted_tokens:
                    self._update_states_after_model_execute(output_token_ids)
            discard_sampled_tokens_req_indices = \
                self.discard_request_indices.np[:self.num_discarded_requests]
            for i in discard_sampled_tokens_req_indices:
                generator = self.input_batch.generators.get(int(i))
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
            # Copy some objects so they don't get modified after returning.
            # This is important when using async scheduling.
            req_ids_output_copy = self.input_batch.req_ids.copy()
            req_id_to_index_output_copy = \
                self.input_batch.req_id_to_index.copy()

            # NOTE: NPU -> CPU Sync happens here.
            # Move as many CPU operations as possible before this sync point.
            logprobs_tensors = sampler_output.logprobs_tensors
            logprobs_lists = logprobs_tensors.tolists() \
                if logprobs_tensors is not None else None

            # Compute prompt logprobs if needed.
            prompt_logprobs_dict = self._get_prompt_logprobs_dict(
                hidden_states[:scheduler_output.total_num_scheduled_tokens],
                scheduler_output.num_scheduled_tokens,
            )

            num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
            sampled_token_ids = sampler_output.sampled_token_ids

            if not self.use_async_scheduling:
                # Get the valid generated tokens.
                max_gen_len = sampled_token_ids.shape[-1]
                if max_gen_len == 1:
                    # No spec decode tokens. It's a tensor.
                    valid_sampled_token_ids = sampled_token_ids.tolist()
                else:
                    # Includes spec decode tokens. It's a numpy array
                    valid_sampled_token_ids, _ = self.rejection_sampler.parse_output(
                        sampled_token_ids,
                        self.input_batch.vocab_size,
                    )
                # Mask out the sampled tokens that should not be sampled.
                for i in discard_sampled_tokens_req_indices:
                    valid_sampled_token_ids[int(i)].clear()
            else:
                valid_sampled_token_ids = []
                invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
                )
                invalid_req_indices_set = set(invalid_req_indices)
                if self.num_spec_tokens <= 0:
                    assert sampled_token_ids.shape[-1] == 1
                # Cache the sampled tokens on the NPU and avoid CPU sync.
                # These will be copied into input_ids in the next step
                # when preparing inputs.
                self.input_batch.prev_sampled_token_ids = sampled_token_ids

                self.input_batch.prev_sampled_token_ids_invalid_indices = \
                    invalid_req_indices_set
                self.input_batch.prev_req_id_to_index = {
                    req_id: i
                    for i, req_id in enumerate(self.input_batch.req_ids)
                    if i not in invalid_req_indices_set
                }
            # Cache the sampled tokens in the model runner, so that the scheduler
            # doesn't need to send them back.
            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
            # the sampled tokens back, because there's no direct communication
            # between the first-stage worker and the last-stage worker.
            for req_idx in range(num_sampled_tokens):
                if self.use_async_scheduling:
                    sampled_ids = [-1] * 1 if \
                        req_idx not in invalid_req_indices_set else None
                else:
                    sampled_ids = valid_sampled_token_ids[req_idx]
                if not sampled_ids:
                    continue

                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
                end_idx = start_idx + len(sampled_ids)
                assert end_idx <= self.model_config.max_model_len, (
                    "Sampled token IDs exceed the max model length. "
                    f"Total number of tokens: {end_idx} > max_model_len: "
                    f"{self.model_config.max_model_len}")

                self.input_batch.token_ids_cpu[req_idx,
                                               start_idx:end_idx] = sampled_ids
                self.input_batch.is_token_ids[req_idx,
                                              start_idx:end_idx] = True
                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                self.input_batch.num_tokens[req_idx] = end_idx
                req_id = self.input_batch.req_ids[req_idx]
                req_state = self.requests[req_id]
                req_state.output_token_ids.extend(sampled_ids)
            sampler_output = self._sample(logits, spec_decode_metadata)

            def propose_draft_token_ids(sampled_token_ids):
                assert self.spec_decode_common_attn_metadata is not None
                self._draft_token_ids = self.propose_draft_token_ids(
                    sampled_token_ids,
                    sampling_metadata,
                    self.input_batch.sampling_metadata,
                    scheduler_output,
                    spec_decode_metadata,
                    positions,
@@ -1867,6 +1735,22 @@ class NPUModelRunner(GPUModelRunner):
                    aux_hidden_states,
                )

            (
                logprobs_lists,
                valid_sampled_token_ids,
                prompt_logprobs_dict,
                req_ids_output_copy,
                req_id_to_index_output_copy,
                invalid_req_indices,
            ) = self._bookkeeping_sync(
                scheduler_output,
                sampler_output,
                logits,
                hidden_states,
                scheduler_output.total_num_scheduled_tokens,
                spec_decode_metadata,
            )

        with ProfileExecuteDuration().capture_async("Draft"):
            if self.speculative_config:
                use_padded_batch_for_eagle = self.speculative_config and \
@@ -1920,13 +1804,164 @@ class NPUModelRunner(GPUModelRunner):
            self.debugger.step()
        return AsyncGPUModelRunnerOutput(
            model_runner_output=model_runner_output,
            sampled_token_ids=sampled_token_ids,
            sampled_token_ids=sampler_output.sampled_token_ids,
            logprobs_tensors=sampler_output.logprobs_tensors,
            invalid_req_indices=invalid_req_indices,
            async_output_copy_stream=self.async_output_copy_stream,
            vocab_size=self.input_batch.vocab_size,
        )
    # overwrite _sample for lmhead_tp_enable and need_accepted_tokens
    def _sample(self, logits, spec_decode_metadata):
        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            if lmhead_tp_enable() and logits is not None:
                logits = logits[:self.input_batch.num_reqs]
            return self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )

        if lmhead_tp_enable() and logits is not None:
            logits = logits[:len(spec_decode_metadata.logits_indices)]
        sampler_output = self.rejection_sampler(
            spec_decode_metadata,
            None,  # draft_probs
            logits,
            sampling_metadata,
        )
        if self.need_accepted_tokens:  # TODO remove this if
            self._update_states_after_model_execute(
                sampler_output.sampled_token_ids)
        return sampler_output
    # TODO: remove this func after eagle_proposer is refactored and
    # _bookkeeping_sync is moved after propose_draft_token_ids
    def _bookkeeping_sync(
        self,
        scheduler_output: "SchedulerOutput",
        sampler_output: SamplerOutput,
        logits: torch.Tensor | None,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> tuple[
        LogprobsLists | None,
        list[list[int]],
        dict[str, LogprobsTensors | None],
        list[str],
        dict[str, int],
        list[int],
    ]:
        # TODO: implement PR 28597 from vllm
        discard_sampled_tokens_req_indices = \
            self.discard_request_indices.np[:self.num_discarded_requests]
        for i in discard_sampled_tokens_req_indices:
            gen = self.input_batch.generators.get(int(i))
            if gen is not None:
                gen.set_offset(gen.get_offset() - 4)

        # Copy some objects so they don't get modified after returning.
        # This is important when using async scheduling.
        req_ids_output_copy = self.input_batch.req_ids.copy()
        req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()

        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
        sampled_token_ids = sampler_output.sampled_token_ids
        logprobs_tensors = sampler_output.logprobs_tensors
        invalid_req_indices = []
        cu_num_tokens: list[int] | None = None
        if not self.use_async_scheduling:
            # Get the valid generated tokens.
            max_gen_len = sampled_token_ids.shape[-1]
            if max_gen_len == 1:
                # No spec decode tokens.
                valid_sampled_token_ids = self._to_list(sampled_token_ids)
                # Mask out the sampled tokens that should not be sampled.
                for i in discard_sampled_tokens_req_indices:
                    valid_sampled_token_ids[int(i)].clear()
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
                    sampled_token_ids,
                    self.input_batch.vocab_size,
                    discard_sampled_tokens_req_indices,
                    return_cu_num_tokens=logprobs_tensors is not None,
                )
        else:
            valid_sampled_token_ids = []
            invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
            invalid_req_indices_set = set(invalid_req_indices)
            if self.num_spec_tokens <= 0:
                assert sampled_token_ids.shape[-1] == 1
            # Cache the sampled tokens on the NPU and avoid CPU sync.
            # These will be copied into input_ids in the next step
            # when preparing inputs.
            self.input_batch.prev_sampled_token_ids = sampled_token_ids

            self.input_batch.prev_req_id_to_index = {
                req_id: i
                for i, req_id in enumerate(self.input_batch.req_ids)
                if i not in invalid_req_indices_set
            }

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        req_ids = self.input_batch.req_ids
        for req_idx in range(num_sampled_tokens):
            if self.use_async_scheduling:
                sampled_ids = [
                    -1
                ] if req_idx not in invalid_req_indices_set else None
            else:
                sampled_ids = valid_sampled_token_ids[req_idx]

            num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0

            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + num_sampled_ids
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                           start_idx:end_idx] = sampled_ids
            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx

            req_id = req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)
        logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens)
                          if not self.use_async_scheduling
                          and logprobs_tensors is not None else None)

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        return (
            logprobs_lists,
            valid_sampled_token_ids,
            prompt_logprobs_dict,
            req_ids_output_copy,
            req_id_to_index_output_copy,
            invalid_req_indices,
        )

    def _build_dummy_attn_metadata(
        self,
        with_prefill: bool,