Sync from v0.13
This commit is contained in:
@@ -1,14 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A layer that compute logits from hidden_stats."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_gather
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_gather,
|
||||
)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
@CustomOp.register("logits_processor")
|
||||
class LogitsProcessor(CustomOp):
|
||||
"""Process logits and apply logits processors from sampling metadata.
|
||||
|
||||
This layer does the following:
|
||||
@@ -17,11 +23,14 @@ class LogitsProcessor(nn.Module):
|
||||
3. Apply logits processors (if any).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size: int,
|
||||
org_vocab_size: Optional[int] = None,
|
||||
scale: Optional[float] = 1.0,
|
||||
logits_as_input: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
org_vocab_size: int | None = None,
|
||||
scale: float = 1.0,
|
||||
logits_as_input: bool = False,
|
||||
soft_cap: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
scale: A scaling factor to apply to the logits.
|
||||
@@ -33,83 +42,65 @@ class LogitsProcessor(nn.Module):
|
||||
self.logits_as_input = logits_as_input
|
||||
# original vocabulary size (without LoRA).
|
||||
self.org_vocab_size = org_vocab_size or vocab_size
|
||||
# Soft cap the logits. Used in Gemma 2.
|
||||
self.soft_cap = soft_cap
|
||||
# Whether to use gather or all-gather to gather the logits.
|
||||
self.use_all_gather = current_platform.use_all_gather()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embedding: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
embedding_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
hidden_states = _prune_hidden_states(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(hidden_states, embedding, embedding_bias)
|
||||
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
if logits is not None:
|
||||
logits *= self.scale
|
||||
|
||||
# Apply logits processors (if any).
|
||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.soft_cap
|
||||
|
||||
if self.scale != 1.0:
|
||||
logits *= self.scale
|
||||
return logits
|
||||
|
||||
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""gather/all-gather the logits tensor across model parallel group."""
|
||||
if self.use_all_gather:
|
||||
# Gather is not supported for some devices such as TPUs.
|
||||
# Use all-gather instead.
|
||||
# NOTE(woosuk): Here, the outputs of every device should not be None
|
||||
# because XLA requires strict SPMD among all devices. Every device
|
||||
# should execute the same operations after gathering the logits.
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
else:
|
||||
# None may be returned for rank > 0
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
return logits
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
|
||||
|
||||
# Gather logits for TP
|
||||
logits = self._gather_logits(logits)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[:, :self.org_vocab_size]
|
||||
logits = logits[..., : self.org_vocab_size]
|
||||
return logits
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"vocab_size={self.vocab_size}"
|
||||
s += f", forg_vocab_size={self.org_vocab_size}"
|
||||
s += f", org_vocab_size={self.org_vocab_size}"
|
||||
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
|
||||
return s
|
||||
|
||||
|
||||
def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
return hidden_states.index_select(0,
|
||||
sampling_metadata.selected_token_indices)
|
||||
|
||||
|
||||
def _apply_logits_processors(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
found_logits_processors = False
|
||||
logits_processed = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
logits_processors = sampling_params.logits_processors
|
||||
|
||||
if logits_processors:
|
||||
found_logits_processors = True
|
||||
for seq_id, logits_row_idx in zip(seq_ids,
|
||||
seq_group.sample_indices):
|
||||
logits_row = logits[logits_row_idx]
|
||||
token_ids = seq_group.seq_data[seq_id].output_token_ids
|
||||
for logits_processor in logits_processors:
|
||||
logits_row = logits_processor(token_ids, logits_row)
|
||||
logits[logits_row_idx] = logits_row
|
||||
|
||||
logits_processed += len(seq_group.sample_indices) + len(
|
||||
seq_group.prompt_logprob_indices)
|
||||
|
||||
if found_logits_processors:
|
||||
# verifies that no rows in logits were missed unexpectedly
|
||||
assert logits_processed == logits.shape[0]
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user