[Feat] Refactor rejection sampler (#4975)
### What this PR does / why we need it?
Currently, spec decoding uses `AscendRejectionSampler`, which extends
`RejectionSampler`. `AscendRejectionSampler` overrides
`forward` of `RejectionSampler` solely to replace the `rejection_sample`
function.
As a result, much of the code in `RejectionSampler` cannot be reused, for example:
- https://github.com/vllm-project/vllm/pull/19482
- https://github.com/vllm-project/vllm/pull/26060
- https://github.com/vllm-project/vllm/pull/29223
#### Proposed Change:
- Delete `AscendRejectionSampler` and use `RejectionSampler` directly in
the model runner.
- Patch `RejectionSampler.expand_batch_to_tokens` and
`RejectionSampler.rejection_sample`. A better long-term approach may be to
make them custom ops.
- Modify `NPUModelRunner` following
https://github.com/vllm-project/vllm/pull/26060
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- [x] test logits processor for spec decoding
- [x] test logprobs for spec decoding
- [x] test logprobs for spec decoding + async scheduling (tested with
https://github.com/vllm-project/vllm-ascend/pull/4893/)
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -239,3 +240,56 @@ def test_suffix_acceptance(
|
|||||||
|
|
||||||
# Heuristic: expect at least 60% acceptance rate at the end.
|
# Heuristic: expect at least 60% acceptance rate at the end.
|
||||||
assert last_accept_rate > 0.60
|
assert last_accept_rate > 0.60
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
|
||||||
|
def test_eagle_logprobs(
|
||||||
|
model_name: str,
|
||||||
|
use_eagle3: bool,
|
||||||
|
):
|
||||||
|
prompt = {"role": "user", "content": "Hello world " * 10}
|
||||||
|
sampling_params = SamplingParams(temperature=0,
|
||||||
|
logprobs=1,
|
||||||
|
max_tokens=10,
|
||||||
|
ignore_eos=False)
|
||||||
|
|
||||||
|
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
|
||||||
|
ref_outputs = ref_llm.chat([prompt], sampling_params)
|
||||||
|
ref_logprobs = []
|
||||||
|
for output in ref_outputs[0].outputs:
|
||||||
|
for logprobs in output.logprobs:
|
||||||
|
for token_id in logprobs:
|
||||||
|
ref_logprobs.append(logprobs[token_id])
|
||||||
|
del ref_llm
|
||||||
|
|
||||||
|
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
|
||||||
|
with VllmRunner(
|
||||||
|
model_name,
|
||||||
|
max_num_seqs=1,
|
||||||
|
max_num_batched_tokens=2048,
|
||||||
|
gpu_memory_utilization=0.6,
|
||||||
|
speculative_config={
|
||||||
|
"method": "eagle3" if use_eagle3 else "eagle",
|
||||||
|
"model": spec_model_name,
|
||||||
|
"num_speculative_tokens": 2,
|
||||||
|
"max_model_len": 128,
|
||||||
|
},
|
||||||
|
max_model_len=128,
|
||||||
|
enforce_eager=False,
|
||||||
|
) as runner:
|
||||||
|
spec_outputs = runner.model.chat([prompt], sampling_params)
|
||||||
|
|
||||||
|
# Collect logprobs outputs from spec decode LLM.
|
||||||
|
spec_logprobs = []
|
||||||
|
for output in spec_outputs[0].outputs:
|
||||||
|
for logprobs in output.logprobs:
|
||||||
|
for token_id in logprobs:
|
||||||
|
spec_logprobs.append(logprobs[token_id])
|
||||||
|
|
||||||
|
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
|
||||||
|
assert math.isclose(ref_logprob.logprob,
|
||||||
|
spec_logprob.logprob,
|
||||||
|
rel_tol=5e-2,
|
||||||
|
abs_tol=1e-1)
|
||||||
|
assert ref_logprob.rank == spec_logprob.rank
|
||||||
|
assert ref_logprob.decoded_token == spec_logprob.decoded_token
|
||||||
|
|||||||
@@ -228,7 +228,7 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when the bug is fixed.
|
# Remove this patch when the bug is fixed.
|
||||||
#
|
#
|
||||||
# ** File: worker/patch_qwen3_next_mtp.py**
|
# ** 11. File: worker/patch_qwen3_next_mtp.py**
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.v1.worker.utils.bind_kv_cache`
|
# 1. `vllm.v1.worker.utils.bind_kv_cache`
|
||||||
# Why:
|
# Why:
|
||||||
@@ -241,7 +241,7 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
|
||||||
#
|
#
|
||||||
# ** File: worker/patch_module.py**
|
# ** 12. File: worker/patch_module.py**
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
|
||||||
# Why:
|
# Why:
|
||||||
@@ -257,3 +257,19 @@
|
|||||||
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
|
||||||
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
|
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
|
||||||
#
|
#
|
||||||
|
# ** 13. File: worker/patch_rejection_sampler.py**
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `vllm.v1.sample.rejection_sampler`
|
||||||
|
# Why:
|
||||||
|
# - some functions from `rejection_sampler` are not supported or slow on npu.
|
||||||
|
# How:
|
||||||
|
# - add npu_top_k_top_p to 'apply_sampling_constraints' func
|
||||||
|
# - add custom triton kernel to `expand_batch_to_tokens` and `rejection_sample`
|
||||||
|
# Related PR (if no, explain why):
|
||||||
|
# https://github.com/vllm-project/vllm/pull/874
|
||||||
|
# https://github.com/vllm-project/vllm/pull/4849
|
||||||
|
# Future Plan:
|
||||||
|
# 1. make these functions as class func of RejectionSampler, create AscendRejectionSampler
|
||||||
|
# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
|
||||||
|
# 2. make these functions custom ops, then remove AscendRejectionSampler
|
||||||
|
#
|
||||||
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa
|
|||||||
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
|
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
|
||||||
import vllm_ascend.patch.worker.patch_rope # noqa
|
import vllm_ascend.patch.worker.patch_rope # noqa
|
||||||
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
|
||||||
|
|||||||
11
vllm_ascend/patch/worker/patch_rejection_sampler.py
Normal file
11
vllm_ascend/patch/worker/patch_rejection_sampler.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import vllm.v1.sample.rejection_sampler as rs
|
||||||
|
|
||||||
|
from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
|
||||||
|
expand_batch_to_tokens,
|
||||||
|
rejection_sample)
|
||||||
|
|
||||||
|
# TODO: delete this patch after apply_sampling_constraints and rejection_sample
|
||||||
|
# are extracted as class functions of RejectionSampler
|
||||||
|
rs.apply_sampling_constraints = apply_sampling_constraints
|
||||||
|
rs.rejection_sample = rejection_sample
|
||||||
|
rs.expand_batch_to_tokens = expand_batch_to_tokens
|
||||||
@@ -2,15 +2,11 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
import vllm.v1.sample.rejection_sampler as rs
|
|
||||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
|
from vllm.v1.sample.rejection_sampler import generate_uniform_probs
|
||||||
generate_uniform_probs)
|
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|
||||||
|
|
||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||||
|
|
||||||
@@ -21,92 +17,6 @@ GREEDY_TEMPERATURE = -1
|
|||||||
MAX_SPEC_LEN = 32
|
MAX_SPEC_LEN = 32
|
||||||
|
|
||||||
|
|
||||||
class AscendRejectionSampler(RejectionSampler, nn.Module):
|
|
||||||
"""
|
|
||||||
The implementation strictly follows the algorithm described in
|
|
||||||
https://arxiv.org/abs/2211.17192.
|
|
||||||
However, we want to clarify the terminology used in the implementation:
|
|
||||||
accepted tokens: tokens that are accepted based on the relationship
|
|
||||||
between the "raw" draft and target probabilities.
|
|
||||||
recovered tokens: tokens that are sampled based on the adjusted probability
|
|
||||||
distribution, which is derived from both the draft and target
|
|
||||||
probabilities.
|
|
||||||
bonus tokens:
|
|
||||||
If all proposed tokens are accepted, the bonus token is added to the
|
|
||||||
end of the sequence. The bonus token is only sampled from the target
|
|
||||||
probabilities. We pass in the bonus tokens instead of sampling them
|
|
||||||
in the rejection sampler to allow for more flexibility in the
|
|
||||||
sampling process. For example, we can use top_p, top_k sampling for
|
|
||||||
bonus tokens, while spec decode does not support these sampling
|
|
||||||
strategies.
|
|
||||||
output tokens:
|
|
||||||
Tokens are finally generated with the rejection sampler.
|
|
||||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
metadata: SpecDecodeMetadata,
|
|
||||||
# [num_tokens, vocab_size]
|
|
||||||
draft_probs: Optional[torch.Tensor],
|
|
||||||
# [num_tokens, vocab_size]
|
|
||||||
target_logits: torch.Tensor,
|
|
||||||
# [batch_size, 1]
|
|
||||||
bonus_token_ids: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
'''
|
|
||||||
Args:
|
|
||||||
metadata:
|
|
||||||
Metadata for spec decoding.
|
|
||||||
draft_probs (Optional[torch.Tensor]):
|
|
||||||
Probability distribution for the draft tokens. Shape is
|
|
||||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
|
||||||
not provided, which is the case for ngram spec decode.
|
|
||||||
target_logits (torch.Tensor):
|
|
||||||
Target model's logits probability distribution.
|
|
||||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
|
||||||
different requests are flattened into a single tensor because
|
|
||||||
this is the shape of the output logits.
|
|
||||||
NOTE: `target_logits` can be updated in place to save memory.
|
|
||||||
bonus_token_ids_tensor (torch.Tensor):
|
|
||||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
|
||||||
Bonus tokens are added to the end of the sequence if all
|
|
||||||
proposed tokens are accepted. We generate the bonus tokens
|
|
||||||
outside of the rejection sampler with the default sampling
|
|
||||||
strategy. It allows for more flexibility in the sampling
|
|
||||||
process such as top_p, top_k sampling.
|
|
||||||
sampling_metadata (SamplingMetadata):
|
|
||||||
Additional metadata needed for sampling, such as temperature,
|
|
||||||
top-k/top-p parameters, or other relevant information.
|
|
||||||
Returns:
|
|
||||||
output_token_ids (torch.Tensor):
|
|
||||||
A tensor containing the final output token IDs.
|
|
||||||
'''
|
|
||||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
|
||||||
# [num_tokens, vocab_size]
|
|
||||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
|
||||||
# `compute_probs` function.
|
|
||||||
target_logits = apply_sampling_constraints(
|
|
||||||
target_logits,
|
|
||||||
metadata.cu_num_draft_tokens,
|
|
||||||
sampling_metadata,
|
|
||||||
)
|
|
||||||
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
|
|
||||||
|
|
||||||
output_token_ids = rejection_sample(
|
|
||||||
metadata.draft_token_ids,
|
|
||||||
metadata.num_draft_tokens,
|
|
||||||
metadata.max_spec_len,
|
|
||||||
metadata.cu_num_draft_tokens,
|
|
||||||
draft_probs,
|
|
||||||
target_probs,
|
|
||||||
bonus_token_ids,
|
|
||||||
sampling_metadata,
|
|
||||||
)
|
|
||||||
return output_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_constraints(
|
def apply_sampling_constraints(
|
||||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||||
@@ -844,6 +754,3 @@ def sample_recovered_tokens_kernel(
|
|||||||
tl.store(
|
tl.store(
|
||||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||||
orig_prob)
|
orig_prob)
|
||||||
|
|
||||||
|
|
||||||
rs.expand_batch_to_tokens = expand_batch_to_tokens
|
|
||||||
|
|||||||
@@ -69,9 +69,11 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
|||||||
MambaSpec, MLAAttentionSpec,
|
MambaSpec, MLAAttentionSpec,
|
||||||
UniformTypeKVCacheSpecs)
|
UniformTypeKVCacheSpecs)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
ModelRunnerOutput,
|
LogprobsLists, LogprobsTensors, ModelRunnerOutput,
|
||||||
|
SamplerOutput,
|
||||||
make_empty_encoder_model_runner_output)
|
make_empty_encoder_model_runner_output)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||||
@@ -112,7 +114,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
|||||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
|
||||||
from vllm_ascend.sample.sampler import AscendSampler
|
from vllm_ascend.sample.sampler import AscendSampler
|
||||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
@@ -424,7 +425,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
)
|
)
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
self.drafter = self._get_drafter()
|
self.drafter = self._get_drafter()
|
||||||
self.rejection_sampler = AscendRejectionSampler(self.sampler)
|
self.rejection_sampler = RejectionSampler(self.sampler)
|
||||||
self.actual_seq_lengths_q = list(
|
self.actual_seq_lengths_q = list(
|
||||||
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||||
self.decode_token_per_req))
|
self.decode_token_per_req))
|
||||||
@@ -1353,7 +1354,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
logits_indices = logits_indices_pcp
|
logits_indices = logits_indices_pcp
|
||||||
metadata = SpecDecodeMetadata(
|
return SpecDecodeMetadata(
|
||||||
draft_token_ids=draft_token_ids,
|
draft_token_ids=draft_token_ids,
|
||||||
num_draft_tokens=num_draft_tokens.tolist(),
|
num_draft_tokens=num_draft_tokens.tolist(),
|
||||||
cu_num_draft_tokens=cu_num_draft_tokens,
|
cu_num_draft_tokens=cu_num_draft_tokens,
|
||||||
@@ -1362,7 +1363,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
bonus_logits_indices=bonus_logits_indices,
|
bonus_logits_indices=bonus_logits_indices,
|
||||||
logits_indices=logits_indices,
|
logits_indices=logits_indices,
|
||||||
)
|
)
|
||||||
return metadata
|
|
||||||
|
|
||||||
def propose_draft_token_ids(
|
def propose_draft_token_ids(
|
||||||
self,
|
self,
|
||||||
@@ -1719,145 +1719,13 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
grammar_output, logits)
|
grammar_output, logits)
|
||||||
|
|
||||||
with ProfileExecuteDuration().capture_async("Sample"):
|
with ProfileExecuteDuration().capture_async("Sample"):
|
||||||
# Sample the next token and get logprobs if needed.
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
|
||||||
if spec_decode_metadata is None:
|
|
||||||
if lmhead_tp_enable() and logits is not None:
|
|
||||||
logits = logits[:self.input_batch.num_reqs]
|
|
||||||
sampler_output = self.sampler(
|
|
||||||
logits=logits,
|
|
||||||
sampling_metadata=sampling_metadata,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if lmhead_tp_enable() and logits is not None:
|
|
||||||
logits = logits[:len(spec_decode_metadata.logits_indices)]
|
|
||||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
|
||||||
# creates a new tensor with separate storage from the original
|
|
||||||
# logits tensor. This means any in-place operations on bonus_logits
|
|
||||||
# won't affect the original logits tensor.
|
|
||||||
assert logits is not None
|
|
||||||
bonus_logits = logits[
|
|
||||||
spec_decode_metadata.bonus_logits_indices]
|
|
||||||
sampler_output = self.sampler(
|
|
||||||
logits=bonus_logits,
|
|
||||||
sampling_metadata=sampling_metadata,
|
|
||||||
)
|
|
||||||
bonus_token_ids = sampler_output.sampled_token_ids
|
|
||||||
|
|
||||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
|
||||||
# separate storage from the original `logits` tensor. Therefore,
|
|
||||||
# it is safe to update `target_logits` in place.
|
|
||||||
target_logits = logits[
|
|
||||||
spec_decode_metadata.target_logits_indices]
|
|
||||||
output_token_ids = self.rejection_sampler(
|
|
||||||
spec_decode_metadata,
|
|
||||||
None, # draft_probs
|
|
||||||
target_logits,
|
|
||||||
bonus_token_ids,
|
|
||||||
sampling_metadata,
|
|
||||||
)
|
|
||||||
sampler_output.sampled_token_ids = output_token_ids
|
|
||||||
if self.need_accepted_tokens:
|
|
||||||
self._update_states_after_model_execute(output_token_ids)
|
|
||||||
discard_sampled_tokens_req_indices = \
|
|
||||||
self.discard_request_indices.np[:self.num_discarded_requests]
|
|
||||||
for i in discard_sampled_tokens_req_indices:
|
|
||||||
generator = self.input_batch.generators.get(int(i))
|
|
||||||
if generator is not None:
|
|
||||||
generator.set_offset(generator.get_offset() - 4)
|
|
||||||
|
|
||||||
# Copy some objects so they don't get modified after returning.
|
|
||||||
# This is important when using async scheduling.
|
|
||||||
req_ids_output_copy = self.input_batch.req_ids.copy()
|
|
||||||
req_id_to_index_output_copy = \
|
|
||||||
self.input_batch.req_id_to_index.copy()
|
|
||||||
|
|
||||||
# NOTE: NPU -> CPU Sync happens here.
|
|
||||||
# Move as many CPU operations as possible before this sync point.
|
|
||||||
logprobs_tensors = sampler_output.logprobs_tensors
|
|
||||||
logprobs_lists = logprobs_tensors.tolists() \
|
|
||||||
if logprobs_tensors is not None else None
|
|
||||||
|
|
||||||
# Compute prompt logprobs if needed.
|
|
||||||
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
|
||||||
hidden_states[:scheduler_output.total_num_scheduled_tokens],
|
|
||||||
scheduler_output.num_scheduled_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
|
||||||
sampled_token_ids = sampler_output.sampled_token_ids
|
|
||||||
|
|
||||||
if not self.use_async_scheduling:
|
|
||||||
# Get the valid generated tokens.
|
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
|
||||||
if max_gen_len == 1:
|
|
||||||
# No spec decode tokens. It's a tensor.
|
|
||||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
|
||||||
else:
|
|
||||||
# Includes spec decode tokens. It's a numpy array
|
|
||||||
valid_sampled_token_ids, _ = self.rejection_sampler.parse_output(
|
|
||||||
sampled_token_ids,
|
|
||||||
self.input_batch.vocab_size,
|
|
||||||
)
|
|
||||||
# Mask out the sampled tokens that should not be sampled.
|
|
||||||
for i in discard_sampled_tokens_req_indices:
|
|
||||||
valid_sampled_token_ids[int(i)].clear()
|
|
||||||
else:
|
|
||||||
valid_sampled_token_ids = []
|
|
||||||
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
|
|
||||||
)
|
|
||||||
invalid_req_indices_set = set(invalid_req_indices)
|
|
||||||
if self.num_spec_tokens <= 0:
|
|
||||||
assert sampled_token_ids.shape[-1] == 1
|
|
||||||
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
|
||||||
# These will be copied into input_ids in the next step
|
|
||||||
# when preparing inputs.
|
|
||||||
self.input_batch.prev_sampled_token_ids = sampled_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
self.input_batch.prev_sampled_token_ids_invalid_indices = \
|
|
||||||
invalid_req_indices_set
|
|
||||||
self.input_batch.prev_req_id_to_index = {
|
|
||||||
req_id: i
|
|
||||||
for i, req_id in enumerate(self.input_batch.req_ids)
|
|
||||||
if i not in invalid_req_indices_set
|
|
||||||
}
|
|
||||||
# Cache the sampled tokens in the model runner, so that the scheduler
|
|
||||||
# doesn't need to send them back.
|
|
||||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
|
||||||
# the sampled tokens back, because there's no direct communication
|
|
||||||
# between the first-stage worker and the last-stage worker.
|
|
||||||
for req_idx in range(num_sampled_tokens):
|
|
||||||
if self.use_async_scheduling:
|
|
||||||
sampled_ids = [-1] * 1 if \
|
|
||||||
req_idx not in invalid_req_indices_set else None
|
|
||||||
else:
|
|
||||||
sampled_ids = valid_sampled_token_ids[req_idx]
|
|
||||||
if not sampled_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
||||||
end_idx = start_idx + len(sampled_ids)
|
|
||||||
assert end_idx <= self.model_config.max_model_len, (
|
|
||||||
"Sampled token IDs exceed the max model length. "
|
|
||||||
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
||||||
f"{self.model_config.max_model_len}")
|
|
||||||
|
|
||||||
self.input_batch.token_ids_cpu[req_idx,
|
|
||||||
start_idx:end_idx] = sampled_ids
|
|
||||||
self.input_batch.is_token_ids[req_idx,
|
|
||||||
start_idx:end_idx] = True
|
|
||||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
||||||
self.input_batch.num_tokens[req_idx] = end_idx
|
|
||||||
req_id = self.input_batch.req_ids[req_idx]
|
|
||||||
req_state = self.requests[req_id]
|
|
||||||
req_state.output_token_ids.extend(sampled_ids)
|
|
||||||
|
|
||||||
def propose_draft_token_ids(sampled_token_ids):
|
def propose_draft_token_ids(sampled_token_ids):
|
||||||
assert self.spec_decode_common_attn_metadata is not None
|
assert self.spec_decode_common_attn_metadata is not None
|
||||||
self._draft_token_ids = self.propose_draft_token_ids(
|
self._draft_token_ids = self.propose_draft_token_ids(
|
||||||
sampled_token_ids,
|
sampled_token_ids,
|
||||||
sampling_metadata,
|
self.input_batch.sampling_metadata,
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
positions,
|
positions,
|
||||||
@@ -1867,6 +1735,22 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
logprobs_lists,
|
||||||
|
valid_sampled_token_ids,
|
||||||
|
prompt_logprobs_dict,
|
||||||
|
req_ids_output_copy,
|
||||||
|
req_id_to_index_output_copy,
|
||||||
|
invalid_req_indices,
|
||||||
|
) = self._bookkeeping_sync(
|
||||||
|
scheduler_output,
|
||||||
|
sampler_output,
|
||||||
|
logits,
|
||||||
|
hidden_states,
|
||||||
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
|
spec_decode_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
with ProfileExecuteDuration().capture_async("Draft"):
|
with ProfileExecuteDuration().capture_async("Draft"):
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
use_padded_batch_for_eagle = self.speculative_config and \
|
use_padded_batch_for_eagle = self.speculative_config and \
|
||||||
@@ -1920,13 +1804,164 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.debugger.step()
|
self.debugger.step()
|
||||||
return AsyncGPUModelRunnerOutput(
|
return AsyncGPUModelRunnerOutput(
|
||||||
model_runner_output=model_runner_output,
|
model_runner_output=model_runner_output,
|
||||||
sampled_token_ids=sampled_token_ids,
|
sampled_token_ids=sampler_output.sampled_token_ids,
|
||||||
logprobs_tensors=sampler_output.logprobs_tensors,
|
logprobs_tensors=sampler_output.logprobs_tensors,
|
||||||
invalid_req_indices=invalid_req_indices,
|
invalid_req_indices=invalid_req_indices,
|
||||||
async_output_copy_stream=self.async_output_copy_stream,
|
async_output_copy_stream=self.async_output_copy_stream,
|
||||||
vocab_size=self.input_batch.vocab_size,
|
vocab_size=self.input_batch.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# override _sample to handle lmhead_tp_enable and need_accepted_tokens
|
||||||
|
def _sample(self, logits, spec_decode_metadata):
|
||||||
|
# Sample the next token and get logprobs if needed.
|
||||||
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
|
if spec_decode_metadata is None:
|
||||||
|
if lmhead_tp_enable() and logits is not None:
|
||||||
|
logits = logits[:self.input_batch.num_reqs]
|
||||||
|
return self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lmhead_tp_enable() and logits is not None:
|
||||||
|
logits = logits[:len(spec_decode_metadata.logits_indices)]
|
||||||
|
sampler_output = self.rejection_sampler(
|
||||||
|
spec_decode_metadata,
|
||||||
|
None, # draft_probs
|
||||||
|
logits,
|
||||||
|
sampling_metadata,
|
||||||
|
)
|
||||||
|
if self.need_accepted_tokens: # TODO remove this if
|
||||||
|
self._update_states_after_model_execute(
|
||||||
|
sampler_output.sampled_token_ids)
|
||||||
|
return sampler_output
|
||||||
|
|
||||||
|
# TODO: remove this func after eagle_proposer is refactored and
|
||||||
|
# _bookkeeping_sync is moved after propose_draft_token_ids
|
||||||
|
def _bookkeeping_sync(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
sampler_output: SamplerOutput,
|
||||||
|
logits: torch.Tensor | None,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
num_scheduled_tokens: int,
|
||||||
|
spec_decode_metadata: SpecDecodeMetadata | None,
|
||||||
|
) -> tuple[
|
||||||
|
LogprobsLists | None,
|
||||||
|
list[list[int]],
|
||||||
|
dict[str, LogprobsTensors | None],
|
||||||
|
list[str],
|
||||||
|
dict[str, int],
|
||||||
|
list[int],
|
||||||
|
]:
|
||||||
|
# TODO: implement PR 28597 from vllm
|
||||||
|
discard_sampled_tokens_req_indices = \
|
||||||
|
self.discard_request_indices.np[:self.num_discarded_requests]
|
||||||
|
for i in discard_sampled_tokens_req_indices:
|
||||||
|
gen = self.input_batch.generators.get(int(i))
|
||||||
|
if gen is not None:
|
||||||
|
gen.set_offset(gen.get_offset() - 4)
|
||||||
|
|
||||||
|
# Copy some objects so they don't get modified after returning.
|
||||||
|
# This is important when using async scheduling.
|
||||||
|
req_ids_output_copy = self.input_batch.req_ids.copy()
|
||||||
|
req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
|
||||||
|
|
||||||
|
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
||||||
|
sampled_token_ids = sampler_output.sampled_token_ids
|
||||||
|
logprobs_tensors = sampler_output.logprobs_tensors
|
||||||
|
invalid_req_indices = []
|
||||||
|
cu_num_tokens: list[int] | None = None
|
||||||
|
if not self.use_async_scheduling:
|
||||||
|
# Get the valid generated tokens.
|
||||||
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
|
if max_gen_len == 1:
|
||||||
|
# No spec decode tokens.
|
||||||
|
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
||||||
|
# Mask out the sampled tokens that should not be sampled.
|
||||||
|
for i in discard_sampled_tokens_req_indices:
|
||||||
|
valid_sampled_token_ids[int(i)].clear()
|
||||||
|
else:
|
||||||
|
# Includes spec decode tokens.
|
||||||
|
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
|
||||||
|
sampled_token_ids,
|
||||||
|
self.input_batch.vocab_size,
|
||||||
|
discard_sampled_tokens_req_indices,
|
||||||
|
return_cu_num_tokens=logprobs_tensors is not None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
valid_sampled_token_ids = []
|
||||||
|
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
|
||||||
|
invalid_req_indices_set = set(invalid_req_indices)
|
||||||
|
|
||||||
|
if self.num_spec_tokens <= 0:
|
||||||
|
assert sampled_token_ids.shape[-1] == 1
|
||||||
|
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
||||||
|
# These will be copied into input_ids in the next step
|
||||||
|
# when preparing inputs.
|
||||||
|
self.input_batch.prev_sampled_token_ids = sampled_token_ids
|
||||||
|
|
||||||
|
self.input_batch.prev_req_id_to_index = {
|
||||||
|
req_id: i
|
||||||
|
for i, req_id in enumerate(self.input_batch.req_ids)
|
||||||
|
if i not in invalid_req_indices_set
|
||||||
|
}
|
||||||
|
|
||||||
|
# Cache the sampled tokens in the model runner, so that the scheduler
|
||||||
|
# doesn't need to send them back.
|
||||||
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||||
|
# the sampled tokens back, because there's no direct communication
|
||||||
|
# between the first-stage worker and the last-stage worker.
|
||||||
|
req_ids = self.input_batch.req_ids
|
||||||
|
for req_idx in range(num_sampled_tokens):
|
||||||
|
if self.use_async_scheduling:
|
||||||
|
sampled_ids = [
|
||||||
|
-1
|
||||||
|
] if req_idx not in invalid_req_indices_set else None
|
||||||
|
else:
|
||||||
|
sampled_ids = valid_sampled_token_ids[req_idx]
|
||||||
|
|
||||||
|
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
|
||||||
|
|
||||||
|
if not sampled_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
||||||
|
end_idx = start_idx + num_sampled_ids
|
||||||
|
assert end_idx <= self.max_model_len, (
|
||||||
|
"Sampled token IDs exceed the max model length. "
|
||||||
|
f"Total number of tokens: {end_idx} > max_model_len: "
|
||||||
|
f"{self.max_model_len}")
|
||||||
|
|
||||||
|
self.input_batch.token_ids_cpu[req_idx,
|
||||||
|
start_idx:end_idx] = sampled_ids
|
||||||
|
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
|
||||||
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||||
|
self.input_batch.num_tokens[req_idx] = end_idx
|
||||||
|
|
||||||
|
req_id = req_ids[req_idx]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
req_state.output_token_ids.extend(sampled_ids)
|
||||||
|
|
||||||
|
logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens)
|
||||||
|
if not self.use_async_scheduling
|
||||||
|
and logprobs_tensors is not None else None)
|
||||||
|
|
||||||
|
# Compute prompt logprobs if needed.
|
||||||
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
||||||
|
hidden_states[:num_scheduled_tokens],
|
||||||
|
scheduler_output.num_scheduled_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
logprobs_lists,
|
||||||
|
valid_sampled_token_ids,
|
||||||
|
prompt_logprobs_dict,
|
||||||
|
req_ids_output_copy,
|
||||||
|
req_id_to_index_output_copy,
|
||||||
|
invalid_req_indices,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_dummy_attn_metadata(
|
def _build_dummy_attn_metadata(
|
||||||
self,
|
self,
|
||||||
with_prefill: bool,
|
with_prefill: bool,
|
||||||
|
|||||||
Reference in New Issue
Block a user