Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -202,10 +202,11 @@ def build_logitsprocs(
         if custom_logitsprocs:
             raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
         logger.warning(
-            "min_p, logit_bias, and min_tokens parameters won't currently work "
-            "with speculative decoding enabled."
+            "min_p and logit_bias parameters won't work with speculative decoding."
         )
-        return LogitsProcessors()
+        return LogitsProcessors(
+            [MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
+        )
 
     custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
     return LogitsProcessors(
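Note: under speculative decoding, build_logitsprocs previously dropped all built-in logits processors; it now keeps min_tokens support and only warns about min_p and logit_bias. A toy, self-contained mirror of the new branch (stand-in names, not vLLM's API):

    import logging

    logger = logging.getLogger(__name__)

    def spec_decode_logitsprocs_sketch(custom_logitsprocs, min_tokens_proc):
        # Custom processors still hard-fail; built-ins degrade to a warning,
        # except the min-tokens processor, which now survives.
        if custom_logitsprocs:
            raise ValueError("custom logits processors unsupported with spec decode")
        logger.warning("min_p and logit_bias won't work with speculative decoding.")
        return [min_tokens_proc]  # previously: return []

    assert spec_decode_logitsprocs_sketch((), "min_tokens_stub") == ["min_tokens_stub"]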
@@ -3,6 +3,7 @@
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, TypeVar
 
+import numpy as np
 import torch
 
 from vllm import SamplingParams
@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor):
         logits.index_put_(self.logits_slice, self.neg_inf_tensor)
         return logits
 
+    def apply_with_spec_decode(
+        self,
+        logits: torch.Tensor,
+        num_draft_tokens: list[int],
+    ) -> torch.Tensor:
+        """Spec-decode version of apply().
+        Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
+        Example: ``num_draft_tokens = [2, 3, 1]``
+        → ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
+        → request 0 owns rows 0-1, request 1 rows 2-4, request 2 row 5.
+        """
+        if not self.min_toks:
+            return logits
+
+        num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
+        cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
+
+        entries = [
+            (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
+            for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
+            if stop_tok_ids
+        ]
+
+        if not entries:
+            return logits
+
+        all_rows: list[np.ndarray] = []  # row indices to mask
+        all_toks: list[np.ndarray] = []  # stop-token ids at those rows
+
+        for req_idx, min_tok, current_len, stop_toks in entries:
+            remaining = min_tok - current_len
+            # How many leading draft positions still need stop-token masking.
+            n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
+
+            if n_mask > 0:
+                offset = cumsum[req_idx]
+                row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
+                n_stop = len(stop_toks)
+                all_rows.append(np.repeat(row_indices, n_stop))
+                all_toks.append(np.tile(stop_toks, n_mask))
+
+        if all_rows:
+            rows_arr = np.concatenate(all_rows)
+            toks_arr = np.concatenate(all_toks)
+            # (row_indices, token_indices) for index_put_ to set -inf.
+            logits_slice = (
+                torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
+                torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
+            )
+            logits.index_put_(logits_slice, self.neg_inf_tensor)
+
+        return logits
+
 
 def process_dict_updates(
     req_entries: dict[int, T],
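To make the row bookkeeping concrete, here is a self-contained sketch of the masking arithmetic above, using the docstring's num_draft_tokens = [2, 3, 1] example (the single-entry min-tokens state is a toy, not vLLM's real self.min_toks contents):

    import numpy as np
    import torch

    num_draft_tokens = [2, 3, 1]
    num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
    cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])  # [0, 2, 5, 6]

    # Toy entry: request 0 generated 3 of min_tokens=5 and stops on tokens {7, 8},
    # so both of its draft rows (rows 0-1) must mask those stop tokens.
    entries = [(0, 5, 3, [7, 8])]

    all_rows, all_toks = [], []
    for req_idx, min_tok, current_len, stop_toks in entries:
        n_mask = int(min(max(min_tok - current_len, 0), num_draft_arr[req_idx]))
        if n_mask > 0:
            rows = np.arange(cumsum[req_idx], cumsum[req_idx] + n_mask, dtype=np.int64)
            all_rows.append(np.repeat(rows, len(stop_toks)))  # [0, 0, 1, 1]
            all_toks.append(np.tile(stop_toks, n_mask))       # [7, 8, 7, 8]

    logits = torch.zeros(6, 10)  # [sum(num_draft_tokens), V]
    logits.index_put_(
        (torch.from_numpy(np.concatenate(all_rows)),
         torch.from_numpy(np.concatenate(all_toks))),
        torch.tensor(float("-inf")),
    )
    assert (logits[:2, [7, 8]] == float("-inf")).all()  # rows 0-1 can't emit 7 or 8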
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from itertools import chain
 from typing import TYPE_CHECKING
 
@@ -148,7 +148,7 @@ class BatchUpdateBuilder:
 class LogitsProcessors:
     """Encapsulates initialized logitsproc objects."""
 
-    def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
+    def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
         self.argmax_invariant: list[LogitsProcessor] = []
         self.non_argmax_invariant: list[LogitsProcessor] = []
         if logitsprocs:
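The Iterator → Iterable widening matters because build_logitsprocs now passes a plain list, which is an Iterable but not an Iterator. A quick illustration:

    from collections.abc import Iterable, Iterator

    procs = ["proc"]  # stand-in for a list of LogitsProcessor objects
    assert isinstance(procs, Iterable)
    assert not isinstance(procs, Iterator)    # lists have no __next__
    assert isinstance(iter(procs), Iterator)  # their iterators do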
@@ -10,12 +10,14 @@ import torch.nn as nn
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
+from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
 from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm import _custom_ops as ops
 
 logger = init_logger(__name__)
 
@@ -292,6 +294,12 @@ class RejectionSampler(nn.Module):
             logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
         )
 
+        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
+            if isinstance(processor, MinTokensLogitsProcessor):
+                logits = processor.apply_with_spec_decode(
+                    logits, metadata.num_draft_tokens
+                )
+
         return logits
 
     @staticmethod
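Only MinTokensLogitsProcessor is rerouted here; other non-argmax-invariant processors are left alone on the spec-decode path, since the flattened [total_draft_tokens, V] logits layout only matches what apply_with_spec_decode expects. A toy mirror of the isinstance dispatch (stand-in classes, not vLLM types):

    import torch

    class StubProcessor:
        pass

    class StubMinTokens(StubProcessor):
        def apply_with_spec_decode(self, logits, num_draft_tokens):
            return logits  # the real method masks stop tokens, as above

    non_argmax_invariant = [StubProcessor(), StubMinTokens()]
    logits = torch.zeros(6, 32)  # flattened: one row per draft token
    for processor in non_argmax_invariant:
        if isinstance(processor, StubMinTokens):
            logits = processor.apply_with_spec_decode(logits, [2, 3, 1])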
@@ -385,14 +393,13 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_logits.argmax(dim=-1)
-        rejection_greedy_sample_kernel[(batch_size,)](
+        ops.rejection_greedy_sample_torch(
             output_token_ids,
             cu_num_draft_tokens,
             draft_token_ids,
             target_argmax,
             bonus_token_ids,
             is_greedy,
-            max_spec_len,
         )
         if sampling_metadata.all_greedy:
             return output_token_ids
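ops.rejection_greedy_sample_torch is a CoreX custom op whose exact argument list isn't fully visible in this extract; as a reference for the greedy rule the replaced Triton kernel implements, here is a hedged pure-torch sketch (argument layout assumed to mirror the call above):

    import torch

    def greedy_rejection_sample_sketch(
        output_token_ids: torch.Tensor,    # [batch, max_spec_len + 1], prefilled with -1
        cu_num_draft_tokens: torch.Tensor, # [batch], inclusive cumsum of draft counts
        draft_token_ids: torch.Tensor,     # [total_draft_tokens]
        target_argmax: torch.Tensor,       # [total_draft_tokens]
        bonus_token_ids: torch.Tensor,     # [batch, 1]
        is_greedy: torch.Tensor,           # [batch], bool
    ) -> None:
        for i in range(output_token_ids.shape[0]):
            if not bool(is_greedy[i]):
                continue
            start = int(cu_num_draft_tokens[i - 1]) if i > 0 else 0
            end = int(cu_num_draft_tokens[i])
            rejected = False
            for pos in range(end - start):
                # Always emit the target model's argmax at this position...
                output_token_ids[i, pos] = target_argmax[start + pos]
                # ...and stop once it disagrees with the draft token.
                if draft_token_ids[start + pos] != target_argmax[start + pos]:
                    rejected = True
                    break
            if not rejected:
                # All drafts accepted: append the bonus token.
                output_token_ids[i, end - start] = bonus_token_ids[i, 0]

    # Example: drafts [5, 9] vs. target argmax [5, 7] -> accept 5, correct 9 to 7.
    out = torch.full((1, 3), -1, dtype=torch.long)
    greedy_rejection_sample_sketch(
        out,
        torch.tensor([2]), torch.tensor([5, 9]), torch.tensor([5, 7]),
        torch.tensor([[42]]), torch.tensor([True]),
    )
    assert out.tolist() == [[5, 7, -1]]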
@@ -424,7 +431,7 @@ def rejection_sample(
     )
 
     # Rejection sampling for random sampling requests.
-    rejection_random_sample_kernel[(batch_size,)](
+    ops.rejection_random_sample_torch(
         output_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
@@ -434,8 +441,6 @@ def rejection_sample(
         recovered_token_ids,
         uniform_probs,
         is_greedy,
-        max_spec_len,
         vocab_size,
-        NO_DRAFT_PROBS=draft_probs is None,
     )
     return output_token_ids
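For the random path, the inputs already encode the standard speculative-sampling acceptance rule: draft token d is kept when u * p_draft(d) <= p_target(d), and the first rejected position falls back to the precomputed recovered token. A minimal sketch of just the acceptance test (toy tensors, not vLLM's API; the exact torch-op signature is assumed, not confirmed by this extract):

    import torch

    def acceptance_mask(
        draft_token_ids: torch.Tensor,  # [n]
        draft_probs: torch.Tensor,      # [n, vocab]
        target_probs: torch.Tensor,     # [n, vocab]
        uniform_probs: torch.Tensor,    # [n], U(0, 1) draws
    ) -> torch.Tensor:
        idx = torch.arange(draft_token_ids.shape[0])
        p_draft = draft_probs[idx, draft_token_ids]
        p_target = target_probs[idx, draft_token_ids]
        # Accept d with probability min(1, p_target(d) / p_draft(d)).
        return uniform_probs * p_draft <= p_target

    draft_probs = torch.tensor([[0.6, 0.4], [0.5, 0.5]])
    target_probs = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
    drafts = torch.tensor([0, 0])
    u = torch.tensor([0.7, 0.7])
    # Row 0: 0.7 * 0.6 = 0.42 <= 0.9 -> accept.
    # Row 1: 0.7 * 0.5 = 0.35 >  0.1 -> reject (use recovered token).
    assert acceptance_mask(drafts, draft_probs, target_probs, u).tolist() == [True, False]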