[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
    """Abstract contract for post-processing newly sampled token ids.

    Implementations take care of detokenization, stop checking and
    freeing/forking sequences with the scheduler for a sequence group.

    The class is tightly coupled with the LLMEngine and is best read as an
    extension of it. It lives on its own to keep LLMEngine smaller and to
    allow two implementations: single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
    ):
        """Build the output processor matching the scheduler configuration.

        A multi-step processor is returned when
        ``scheduler_config.num_lookahead_slots`` is non-zero; otherwise a
        single-step processor is returned.
        """
        needs_multi_step = scheduler_config.num_lookahead_slots != 0

        if needs_multi_step:
            # Imported lazily: a module-level import would be circular.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

        # Imported lazily: a module-level import would be circular.
        from vllm.engine.output_processor.single_step import (
            SingleStepOutputProcessor)
        return SingleStepOutputProcessor(scheduler_config, detokenizer,
                                         scheduler, seq_counter,
                                         stop_checker)

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Consume the new token ids produced for ``sequence_group``.

        Covers detokenization, stop checking, and freeing/forking sequences
        in the scheduler.
        """

    @abstractmethod
    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Copy the prompt logprobs carried by ``outputs`` onto ``seq_group``.
        """

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Callable, List, cast
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.single_step import (
single_step_process_prompt_logprob)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per
    invocation.

    This is currently mutually exclusive with advanced sampling techniques
    like beam search, which motivates the separation of this logic from the
    single step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where
    the number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            detokenizer: used to detokenize sequences in place; a falsy
                value disables detokenization and EOS truncation.
            scheduler: stored as-is; not used by the methods of this class.
            seq_counter: stored as-is; not used by the methods of this class.
            get_tokenizer_for_seq: returns the tokenizer for a sequence; its
                ``eos_token_id`` is used to truncate tokens after EOS.
            stop_checker: applies the stop conditions after each new token.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
            seq_group: the outputs are associated with this
                [`SequenceGroup`][vllm.sequence.SequenceGroup]
            outputs: the
                [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
                for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            assert isinstance(output, CompletionSequenceGroupOutput)
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache
    def _log_prompt_logprob_unsupported_warning_once():
        """Log the "prompt logprob unsupported" warning.

        The result is memoized by ``functools.lru_cache``, so repeated calls
        emit the warning only once.
        """
        # Reminder: Please update docs/features/compatibility_matrix.md
        # if the feature combo becomes valid.
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence
        group.

        This only supports sequence groups of size 1. It supports greater
        than one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        Args:
            sequence_group: the group the outputs belong to.
            outputs: one output per step for this sequence group.
            is_async: indicates whether this postprocessor runs in
                parallel with the GPU forward pass and is processing
                tokens from the previous step. If this is true, then
                no tokens need to be appended since it is already done
                externally (before the next schedule() call).
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
        # if a client disconnects from the api server.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        # `not seqs` covers an empty result as well as None, so the
        # FINISHED_ABORTED fallback runs whenever no RUNNING sequence was
        # found. The previous `is None` test skipped it for an empty list.
        if not seqs:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        # Count, per step index, the steps whose first sample is valid.
        for output in outputs:
            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
                sequence_group.metrics.spec_token_acceptance_counts[
                    output.step_index] += 1

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]
        seq_id = seq.seq_id

        # This method is defined in the more generic
        # SequenceGroupOutputProcessor, but here we assume that the outputs
        # are of a more specific type.
        assert all(
            isinstance(output, CompletionSequenceGroupOutput)
            for output in outputs)
        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
        assert all(seq_id == output.samples[0].parent_seq_id
                   for output in compl_outputs)

        if is_async:
            # Async case: We process tokens one by one. Here, we know the
            # token was already appended, so we only need to do the rest of
            # the postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
            return

        # Standard multi-step case
        # Since there's only one sequence per sequence group,
        # we can take the first sample.
        samples = [output.samples[0] for output in compl_outputs]

        # Entries in sample tokens may be invalid (eg. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples
            if sample.output_token != VLLM_INVALID_TOKEN_ID
        ]

        # When both spec-decode and pre-fill chunking are enabled, we
        # don't have guaranteed samples here (e.g. all -1s).
        if valid_samples:
            self._process_seq_outputs(seq, valid_samples,
                                      sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        """Detokenize the sequence in place (when enabled) and then apply
        the stop checker to it.
        """
        new_char_count = 0
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the valid sampled tokens to ``seq`` one at a time.

        The token list is first truncated to ``max_tokens`` and to the first
        EOS token. After each appended token the sequence is detokenized and
        stop-checked; appending stops as soon as the sequence is finished.
        """
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]
        output_embeds = [sample.output_embed for sample in valid_samples]

        # Truncate to max_tokens if necessary. Skipped when max_tokens is
        # None, since subtracting from None would raise TypeError.
        # Only the token id list is shortened; the zip() below stops at the
        # shortest input, so logprobs/embeds are truncated implicitly.
        max_tokens = sampling_params.max_tokens
        if max_tokens is not None:
            remaining_tokens = max_tokens - (seq.get_output_len() +
                                             len(output_token_ids))
            if remaining_tokens < 0:
                output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos and self.detokenizer:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for idx, token_id in enumerate(output_token_ids):
                if token_id == eos_token_id:
                    output_token_ids = output_token_ids[:idx + 1]
                    break

        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
        # Incrementally append tokens to the sequence, as if we had only one
        # new token.
        for output_token_id, output_logprob, output_embed in zip(
                output_token_ids, output_logprobs, output_embeds):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
                token_embed=output_embed,
            )

            if is_prefill_sampled_token:
                is_prefill_sampled_token = False
            else:
                # Update num_computed_tokens iff the sampled token is not
                # from a prefill step.
                seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
def single_step_process_prompt_logprob(
        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
        output: CompletionSequenceGroupOutput) -> None:
    """Fold one step's prompt logprobs into the sequence group.

    Nothing happens when ``output`` carries no prompt logprobs. When the
    sequence group has no prompt logprobs yet, a leading ``None`` is inserted
    for the first prompt token, which has no logprob of its own.

    Args:
        sg_output_proc:
            [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
            instance; must expose a ``detokenizer`` attribute
        seq_group: the output is associated with this
            [`SequenceGroup`][vllm.sequence.SequenceGroup]
        output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
            for a single scheduler step
    """
    step_logprobs = output.prompt_logprobs
    if step_logprobs is None:
        return

    # First (or only) prefill chunk: for N prompt tokens the Sampler yields
    # N-1 prompt logprobs because the token at idx 0 has none, so pad the
    # front with None to keep positions aligned.
    if not seq_group.prompt_logprobs:
        step_logprobs = [None] + step_logprobs
        seq_group.prompt_logprobs = []

    assert hasattr(sg_output_proc, 'detokenizer')
    detokenize_requested = seq_group.sampling_params.detokenize
    if detokenize_requested and sg_output_proc.detokenizer:
        already_stored = len(seq_group.prompt_logprobs)
        sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
            seq_group,
            step_logprobs,
            position_offset=already_stored)

    seq_group.prompt_logprobs.extend(step_logprobs)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """Output processor for models that emit at most one token per call.

    "Output processing" runs after the model returns generated token ids and
    before the next batch is scheduled. It includes detokenization and
    deciding whether a sequence is finished (e.g. via max len or eos token).

    Restricting to a single token per invocation precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a
    way that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(self, scheduler_config: SchedulerConfig,
                 detokenizer: Detokenizer, scheduler: List[Scheduler],
                 seq_counter: Counter, stop_checker: StopChecker):
        """Keep references to the collaborators used during processing."""
        self.stop_checker = stop_checker
        self.seq_counter = seq_counter
        self.scheduler = scheduler
        self.detokenizer = detokenizer
        self.scheduler_config = scheduler_config

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Handle the single new output of ``sequence_group``.

        Appends the new token (unless ``is_async``), invokes the detokenizer
        and marks the sequence as finished when a stop condition is met.

        Args:
            sequence_group: the group the output belongs to.
            outputs: must contain exactly one element.
            is_async: indicates whether this postprocessor runs in
                parallel with the GPU forward pass and is processing
                tokens from the previous step. If this is true, then
                no tokens need to be appended since it is already done
                externally (before the next schedule() call).
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        only_output = outputs[0]
        return self._process_sequence_group_outputs(sequence_group,
                                                    only_output, is_async)

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs for one step of a single-step-scheduled
        computation.

        Args:
            seq_group: the output is associated with this
                [`SequenceGroup`][vllm.sequence.SequenceGroup]
            outputs: the
                [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
                for a single scheduler step
        """
        assert len(outputs) == 1, "Single step should only have 1 output."
        (step_output, ) = outputs
        assert isinstance(step_output, CompletionSequenceGroupOutput)
        single_step_process_prompt_logprob(self, seq_group, step_output)

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        """Append, detokenize and stop-check the first sample of ``outputs``
        on the group's first sequence, freeing it once finished.
        """
        params = seq_group.sampling_params
        first_sample = outputs.samples[0]
        target_seq = seq_group.first_seq

        # In the async case the token has already been appended externally.
        if not is_async:
            target_seq.append_token_id(first_sample.output_token,
                                       first_sample.logprobs,
                                       first_sample.output_embed)

        added_chars = 0
        if params.detokenize and self.detokenizer:
            added_chars = self.detokenizer.decode_sequence_inplace(
                target_seq, params)

        self.stop_checker.maybe_stop_sequence(
            target_seq,
            added_chars,
            params,
            lora_req=seq_group.lora_request,
        )

        if not target_seq.is_finished():
            return
        for sched in self.scheduler:
            sched.free_seq(target_seq)

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
    """Stop-condition logic factored out of the LLMEngine.

    Decides whether a sequence is finished because the eos token was
    emitted, a stop token or stop string was emitted, ``max_tokens`` has
    been consumed, or the max model len has been reached.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        # Read this through `self._get_max_model_len`, never directly.
        self._max_model_len = max_model_len

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
        """Return the LoRA request's ``long_lora_max_len`` when it is set,
        otherwise the configured max model len.
        """
        has_lora_limit = lora_req and lora_req.long_lora_max_len
        if has_lora_limit:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` as finished if it meets a stop condition.

        Conditions are tried in order: EOS token, stop token id, stop
        string, then the length caps. The first match sets ``seq.status``
        (plus ``seq.stop_reason`` for stop tokens/strings) and returns.

        Args:
            seq: the sequence to inspect and possibly update.
            new_char_count: the number of chars added to the sequence's
                output text for the newly generated token.
            sampling_params: source of the stop settings.
            lora_req: optional LoRA request that may override the max
                model len.
        """
        # Until min_tokens have been generated, no stop condition applies.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # EOS token.
        emitted_eos = (not sampling_params.ignore_eos
                       and seq.get_last_token_id() == seq.eos_token_id)
        if emitted_eos:
            # Drop the text of the EOS token unless the caller asked to keep
            # it; this prevents unintended exposure of the EOS token.
            if (new_char_count
                    and not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Stop token. This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        stop_token_ids = sampling_params.stop_token_ids or ()
        if last_token_id in stop_token_ids:
            if (new_char_count
                    and not sampling_params.include_stop_str_in_output):
                # Remove the text of the last token.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Stop strings.
        match = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if match is not None:
            matched_str, cut_at = match
            if cut_at != -1:
                seq.output_text = seq.output_text[:cut_at]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = matched_str
            return

        # Length caps: max_model_len first, then max_tokens.
        if (seq.get_len() >= self._get_max_model_len(lora_req)
                or seq.get_output_len() == sampling_params.max_tokens):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: List[str],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Look for a stop string in the newly added part of ``output_text``.

        Returns ``(stop_string, offset)`` for the first stop string (in the
        order given by ``stop``) that matches, or ``None`` when nothing
        matches. ``offset`` is the length to which ``output_text`` should be
        truncated, or -1 for no truncation.
        """
        if not (new_char_count and stop):
            return None

        text_len = len(output_text)
        for candidate in stop:
            candidate_len = len(candidate)
            # Avoid searching already-searched text: start far enough back
            # that a match overlapping the new characters is still found.
            search_from = 1 - new_char_count - candidate_len
            found_at = output_text.find(candidate, search_from)
            if found_at == -1:
                continue

            if not include_in_output:
                # Truncate to the beginning of the stop string.
                return candidate, found_at

            # Truncate to the end of the stop string, unless that is
            # already the end of the text.
            end_at = found_at + candidate_len
            if end_at >= text_len:
                return candidate, -1
            return candidate, end_at

        return None

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List
from typing import Sequence as GenericSequence
from typing import cast
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
def create_output_by_sequence_group(
        outputs: GenericSequence[SamplerOutput],
        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
    """Transpose sampler outputs from [step][sequence group] layout into
    [sequence group][step] layout.

    Args:
        outputs: one iterable of per-sequence-group outputs for each step.
        num_seq_groups: number of per-group lists to create in the result.

    Returns:
        A list with ``num_seq_groups`` entries; entry ``i`` holds the outputs
        found at position ``i`` of every step, in step order.
    """
    per_group: List[List[CompletionSequenceGroupOutput]] = [
        [] for _ in range(num_seq_groups)
    ]

    for step_outputs in outputs:
        for group_idx, group_output in enumerate(step_outputs):
            bucket = per_group[group_idx]
            bucket.append(group_output)

    # Widen to the more generic type that CompletionSequenceGroupOutput
    # inherits from.
    return cast(List[List[SequenceGroupOutput]], per_group)