Add minimal vLLM 0.16.1 build repo for BI-V150
280 vllm/v1/worker/gpu/sample/logit_bias.py Normal file
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128


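# The three caps above fix the per-request scratch width. As a worked example
# (illustrative only; max_num_reqs is chosen by the caller, not by this file):
# with max_num_reqs = 1024, allowed_token_ids is a 1024 x 1024 int32 table
# (4 MiB), the two logit-bias tables add another 4 MiB + 4 MiB, and
# stop_token_ids is 1024 x 128 int32 (512 KiB).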
class LogitBiasState:
    def __init__(self, max_num_reqs: int, device: torch.device):
        self.max_num_reqs = max_num_reqs

        # Allowed token IDs.
        self.num_allowed_token_ids = UvaBackedTensor(
            self.max_num_reqs, dtype=torch.int32
        )
        self.allowed_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )
        # Logit bias.
        self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.logit_bias_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.int32,
            device=device,
        )
        self.logit_bias = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.float32,
            device=device,
        )
        # Min tokens.
        self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.stop_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

        # Whether each request uses any of the above.
        self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        # Whether this request uses any logit bias.
        use_logit_bias = False

        # Allowed token IDs.
        allowed_token_ids = sampling_params.allowed_token_ids
        if allowed_token_ids:
            num_allowed_token_ids = len(allowed_token_ids)
            if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
                raise ValueError(
                    f"Too many allowed token IDs: {num_allowed_token_ids}. "
                    f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
                )
            self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
            self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
            use_logit_bias = True
        else:
            self.num_allowed_token_ids.np[req_idx] = 0

        # Logit bias.
        logit_bias = sampling_params.logit_bias
        if logit_bias:
            num_logit_bias = len(logit_bias)
            if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
                raise ValueError(
                    f"Too many logit bias tokens: {num_logit_bias}. "
                    f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
                )
            self.num_logit_bias.np[req_idx] = num_logit_bias
            self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
            self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
            use_logit_bias = True
        else:
            self.num_logit_bias.np[req_idx] = 0

        # Min tokens.
        min_tokens = sampling_params.min_tokens
        min_len = prompt_len + min_tokens
        self.min_lens.np[req_idx] = min_len
        stop_token_ids = sampling_params.all_stop_token_ids
        if min_tokens > 0 and stop_token_ids:
            num_stop_token_ids = len(stop_token_ids)
            if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
                raise ValueError(
                    f"Too many stop tokens: {num_stop_token_ids}. "
                    f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
                )
            self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
            self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
            use_logit_bias = True
        else:
            self.num_stop_token_ids.np[req_idx] = 0

        self.use_logit_bias[req_idx] = use_logit_bias

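    # Illustrative usage (not part of the original commit): a request that
    # biases two tokens and suppresses stop tokens for its first 8 decode
    # steps would be staged roughly as:
    #
    #   params = SamplingParams(min_tokens=8, logit_bias={42: -2.0, 7: 1.5})
    #   state.add_request(req_idx=0, prompt_len=16, sampling_params=params)
    #   state.apply_staged_writes()  # publish staged values before sampling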
    def apply_staged_writes(self) -> None:
        self.num_allowed_token_ids.copy_to_uva()
        self.allowed_token_ids.apply_write()

        self.num_logit_bias.copy_to_uva()
        self.logit_bias_token_ids.apply_write()
        self.logit_bias.apply_write()

        self.min_lens.copy_to_uva()
        self.num_stop_token_ids.copy_to_uva()
        self.stop_token_ids.apply_write()

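    # Note (added commentary, an interpretation rather than documented
    # behavior): staging CPU-side writes in add_request() and flushing them
    # here in one place presumably batches host-to-device traffic once per
    # scheduling step, instead of issuing one small copy per request.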
    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> None:
        if not np.any(self.use_logit_bias[idx_mapping_np]):
            # No request uses logit bias. Skip the kernel launch.
            return

        apply_logit_bias(
            logits,
            idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
            self.num_logit_bias.gpu,
            self.logit_bias_token_ids.gpu,
            self.logit_bias.gpu,
            self.min_lens.gpu,
            self.num_stop_token_ids.gpu,
            self.stop_token_ids.gpu,
        )


@triton.jit
def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    # Logit bias.
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    # Min tokens.
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    block = tl.arange(0, BLOCK_SIZE)

    # Allowed token IDs.
    num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
    if num_allowed_token_ids > 0:
        mask = block < num_allowed_token_ids

        # Save logits for allowed token IDs.
        allowed_token_ids = tl.load(
            allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
        )

        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )

        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )

    # Logit bias.
    num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
    if num_logit_bias > 0:
        mask = block < num_logit_bias
        token_ids = tl.load(
            bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)

    # Apply min tokens.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
        stop_token_ids = tl.load(
            stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )


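# Launch note (added commentary, not from the original commit): the wrapper
# below launches one kernel program per batch row, and BLOCK_SIZE is sized to
# cover the widest of the three per-request ID tables in one tl.arange sweep.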
def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = triton.next_power_of_2(
        max(
            allowed_token_ids.shape[-1],
            logit_bias_token_ids.shape[-1],
            stop_token_ids.shape[-1],
        )
    )
    LOGITS_BLOCK_SIZE = 8192
    _bias_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
    )
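

# The following reference is NOT part of the original commit. It is a plain
# PyTorch sketch of the kernel's semantics (allow-list, additive bias,
# min-tokens stop masking), added for clarity and as a possible oracle for
# unit-testing the Triton path. The helper name _apply_logit_bias_torch is
# hypothetical; the tensor shapes follow the launch arguments above.
def _apply_logit_bias_torch(
    logits: torch.Tensor,  # [num_reqs, vocab_size], modified in place
    idx_mapping: torch.Tensor,  # [num_reqs], batch row -> request state row
    pos: torch.Tensor,  # [num_reqs], current position of each sequence
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    for batch_idx in range(logits.shape[0]):
        row = int(idx_mapping[batch_idx])
        # Allow-list: keep only the listed token IDs, -inf everything else.
        n = int(num_allowed_token_ids[row])
        if n > 0:
            allowed = allowed_token_ids[row, :n].long()
            kept = logits[batch_idx, allowed].clone()
            logits[batch_idx].fill_(float("-inf"))
            logits[batch_idx, allowed] = kept
        # Additive per-token bias.
        n = int(num_logit_bias[row])
        if n > 0:
            ids = logit_bias_token_ids[row, :n].long()
            logits[batch_idx, ids] += logit_bias[row, :n].to(logits.dtype)
        # Min tokens: mask stop tokens until the minimum length is reached.
        n = int(num_stop_token_ids[row])
        if n > 0 and int(pos[batch_idx]) < int(min_lens[row]):
            ids = stop_token_ids[row, :n].long()
            logits[batch_idx, ids] = float("-inf")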