# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

# Capacity limits of the fixed-size per-request tables allocated below.
MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128


class LogitBiasState:
    """State for logit post-processing: allowed-token-ID masking, additive
    logit bias, and stop-token suppression until `min_tokens` is reached.

    Per-request counters live in UVA-backed tensors; the wide token-ID tables
    are staged host-side and flushed to the device by apply_staged_writes().
    """

    def __init__(self, max_num_reqs: int, device: torch.device):
        self.max_num_reqs = max_num_reqs

        # Allowed token IDs.
        self.num_allowed_token_ids = UvaBackedTensor(
            self.max_num_reqs, dtype=torch.int32
        )
        self.allowed_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )
        # Logit bias.
        self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.logit_bias_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.int32,
            device=device,
        )
        self.logit_bias = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.float32,
            device=device,
        )
        # Min tokens.
        self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.stop_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

        # True if the request uses any of the features above.
        self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        # Whether this request uses any logit-bias feature.
        use_logit_bias = False

        # Allowed token IDs.
        allowed_token_ids = sampling_params.allowed_token_ids
        if allowed_token_ids:
            num_allowed_token_ids = len(allowed_token_ids)
            if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
                raise ValueError(
                    f"Too many allowed token IDs: {num_allowed_token_ids}. "
                    f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
                )
            self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
            self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
            use_logit_bias = True
        else:
            self.num_allowed_token_ids.np[req_idx] = 0

        # Logit bias.
        logit_bias = sampling_params.logit_bias
        if logit_bias:
            num_logit_bias = len(logit_bias)
            if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
                raise ValueError(
                    f"Too many logit bias tokens: {num_logit_bias}. "
                    f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
                )
            self.num_logit_bias.np[req_idx] = num_logit_bias
            # Materialize the dict views into lists before staging the writes.
            self.logit_bias_token_ids.stage_write(
                req_idx, 0, list(logit_bias.keys())
            )
            self.logit_bias.stage_write(req_idx, 0, list(logit_bias.values()))
            use_logit_bias = True
        else:
            self.num_logit_bias.np[req_idx] = 0

        # Min tokens.
        min_tokens = sampling_params.min_tokens
        min_len = prompt_len + min_tokens
        self.min_lens.np[req_idx] = min_len
        stop_token_ids = sampling_params.all_stop_token_ids
        if min_tokens > 0 and stop_token_ids:
            num_stop_token_ids = len(stop_token_ids)
            if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
                raise ValueError(
                    f"Too many stop tokens: {num_stop_token_ids}. "
                    f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
                )
            self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
            # all_stop_token_ids is a set; convert it to a list for staging.
            self.stop_token_ids.stage_write(req_idx, 0, list(stop_token_ids))
            use_logit_bias = True
        else:
            self.num_stop_token_ids.np[req_idx] = 0

        self.use_logit_bias[req_idx] = use_logit_bias

    def apply_staged_writes(self) -> None:
        # Flush all staged host-side writes so the kernel sees current state.
        self.num_allowed_token_ids.copy_to_uva()
        self.allowed_token_ids.apply_write()

        self.num_logit_bias.copy_to_uva()
        self.logit_bias_token_ids.apply_write()
        self.logit_bias.apply_write()

        self.min_lens.copy_to_uva()
        self.num_stop_token_ids.copy_to_uva()
        self.stop_token_ids.apply_write()

    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> None:
        if not np.any(self.use_logit_bias[idx_mapping_np]):
            # No request in the batch uses logit bias. Skip the kernel launch.
            return

        # Launch the Triton kernel via the module-level wrapper below.
        apply_logit_bias(
            logits,
            idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
            self.num_logit_bias.gpu,
            self.logit_bias_token_ids.gpu,
            self.logit_bias.gpu,
            self.min_lens.gpu,
            self.num_stop_token_ids.gpu,
            self.stop_token_ids.gpu,
        )
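
    # Typical per-step flow (a usage sketch; `state`, `params`, and the tensor
    # arguments are illustrative names for values supplied by the surrounding
    # vLLM v1 GPU worker, not guaranteed API):
    #
    #   state = LogitBiasState(max_num_reqs, device)
    #   state.add_request(req_idx, prompt_len, params)  # stage host-side writes
    #   state.apply_staged_writes()                     # flush to UVA/GPU buffers
    #   state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)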


@triton.jit
def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    # Logit bias.
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    # Min tokens.
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    # One program instance per row of `logits`; map it to its request state.
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    block = tl.arange(0, BLOCK_SIZE)

    # Allowed token IDs.
    num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
    if num_allowed_token_ids > 0:
        mask = block < num_allowed_token_ids

        # Save logits for allowed token IDs.
        allowed_token_ids = tl.load(
            allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
        )

        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )

        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )

    # Logit bias.
    num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
    if num_logit_bias > 0:
        mask = block < num_logit_bias
        token_ids = tl.load(
            bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)

    # Apply min tokens: while the output is shorter than min_len, suppress
    # all stop tokens by setting their logits to -inf.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
        stop_token_ids = tl.load(
            stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )


def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    # A single block must cover the widest of the per-request tables.
    BLOCK_SIZE = triton.next_power_of_2(
        max(
            allowed_token_ids.shape[-1],
            logit_bias_token_ids.shape[-1],
            stop_token_ids.shape[-1],
        )
    )
    LOGITS_BLOCK_SIZE = 8192
    # Launch one program per batch row.
    _bias_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
    )
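

# --- Minimal smoke-test sketch (an illustrative addition, not vLLM API). ---
# It exercises the kernel wrapper directly with dummy tensors, assuming a CUDA
# device with Triton available. Request 0 restricts sampling to tokens
# {1, 2, 3}; request 1 is unconstrained.
if __name__ == "__main__":
    num_reqs, vocab_size = 2, 32000
    device = torch.device("cuda")
    logits = torch.randn(num_reqs, vocab_size, device=device)
    idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
    pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)

    num_allowed = torch.tensor([3, 0], dtype=torch.int32, device=device)
    allowed = torch.zeros(
        (num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS), dtype=torch.int32, device=device
    )
    allowed[0, :3] = torch.tensor([1, 2, 3], dtype=torch.int32, device=device)

    # No additive bias and no min-tokens suppression in this sketch.
    num_bias = torch.zeros(num_reqs, dtype=torch.int32, device=device)
    bias_ids = torch.zeros(
        (num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS), dtype=torch.int32, device=device
    )
    bias = torch.zeros(
        (num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS), dtype=torch.float32, device=device
    )
    min_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
    num_stop = torch.zeros(num_reqs, dtype=torch.int32, device=device)
    stop_ids = torch.zeros(
        (num_reqs, MAX_NUM_STOP_TOKEN_IDS), dtype=torch.int32, device=device
    )

    apply_logit_bias(
        logits,
        idx_mapping,
        pos,
        num_allowed,
        allowed,
        num_bias,
        bias_ids,
        bias,
        min_lens,
        num_stop,
        stop_ids,
    )
    # Everything outside {1, 2, 3} in row 0 should now be -inf.
    assert torch.isinf(logits[0, 0]).item()
    assert not torch.isinf(logits[0, 1]).item()
    print("logit bias kernel smoke test passed")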