Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)
This commit is contained in:
@@ -10,12 +10,14 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -292,6 +294,12 @@ class RejectionSampler(nn.Module):
|
||||
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
|
||||
)
|
||||
|
||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||
if isinstance(processor, MinTokensLogitsProcessor):
|
||||
logits = processor.apply_with_spec_decode(
|
||||
logits, metadata.num_draft_tokens
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
@@ -385,14 +393,13 @@ def rejection_sample(
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_logits.argmax(dim=-1)
|
||||
rejection_greedy_sample_kernel[(batch_size,)](
|
||||
ops.rejection_greedy_sample_torch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
return output_token_ids
|
||||
@@ -424,7 +431,7 @@ def rejection_sample(
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size,)](
|
||||
ops.rejection_random_sample_torch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -434,8 +441,6 @@ def rejection_sample(
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
Reference in New Issue
Block a user