# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request logit-bias state and the fused Triton kernel that applies
allowed-token masking, additive logit bias, and min-tokens stop suppression."""

import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128


class LogitBiasState:
    def __init__(self, max_num_reqs: int, device: torch.device):
        self.max_num_reqs = max_num_reqs

        # Allowed token IDs.
        self.num_allowed_token_ids = UvaBackedTensor(
            self.max_num_reqs, dtype=torch.int32
        )
        self.allowed_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

        # Logit bias.
        self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.logit_bias_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.int32,
            device=device,
        )
        self.logit_bias = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.float32,
            device=device,
        )

        # Min tokens.
        self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.stop_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

        # Whether each request uses any of the above.
        self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        # Whether this request uses any kind of logit bias.
        use_logit_bias = False

        # Allowed token IDs.
        allowed_token_ids = sampling_params.allowed_token_ids
        if allowed_token_ids:
            num_allowed_token_ids = len(allowed_token_ids)
            if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
                raise ValueError(
                    f"Too many allowed token IDs: {num_allowed_token_ids}. "
                    f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
                )
            self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
            self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
            use_logit_bias = True
        else:
            self.num_allowed_token_ids.np[req_idx] = 0

        # Logit bias.
        logit_bias = sampling_params.logit_bias
        if logit_bias:
            num_logit_bias = len(logit_bias)
            if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
                raise ValueError(
                    f"Too many logit bias tokens: {num_logit_bias}. "
                    f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
                )
            self.num_logit_bias.np[req_idx] = num_logit_bias
            # dict preserves insertion order, so keys() and values() line up.
            self.logit_bias_token_ids.stage_write(
                req_idx, 0, list(logit_bias.keys())
            )
            self.logit_bias.stage_write(req_idx, 0, list(logit_bias.values()))
            use_logit_bias = True
        else:
            self.num_logit_bias.np[req_idx] = 0

        # Min tokens: suppress stop tokens until the output reaches min_len.
        min_tokens = sampling_params.min_tokens
        min_len = prompt_len + min_tokens
        self.min_lens.np[req_idx] = min_len
        stop_token_ids = sampling_params.all_stop_token_ids
        if min_tokens > 0 and stop_token_ids:
            num_stop_token_ids = len(stop_token_ids)
            if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
                raise ValueError(
                    f"Too many stop tokens: {num_stop_token_ids}. "
                    f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
                )
            self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
            self.stop_token_ids.stage_write(req_idx, 0, list(stop_token_ids))
            use_logit_bias = True
        else:
            self.num_stop_token_ids.np[req_idx] = 0

        self.use_logit_bias[req_idx] = use_logit_bias

    def apply_staged_writes(self) -> None:
        # Flush all CPU-side staged writes to the GPU-visible buffers.
        self.num_allowed_token_ids.copy_to_uva()
        self.allowed_token_ids.apply_write()
        self.num_logit_bias.copy_to_uva()
        self.logit_bias_token_ids.apply_write()
        self.logit_bias.apply_write()
        self.min_lens.copy_to_uva()
        self.num_stop_token_ids.copy_to_uva()
        self.stop_token_ids.apply_write()

    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> None:
        if not np.any(self.use_logit_bias[idx_mapping_np]):
            # No request in this batch uses logit bias. Skip the kernel launch.
            return
        apply_logit_bias(
            logits,
            idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
            self.num_logit_bias.gpu,
            self.logit_bias_token_ids.gpu,
            self.logit_bias.gpu,
            self.min_lens.gpu,
            self.num_stop_token_ids.gpu,
            self.stop_token_ids.gpu,
        )
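
# Illustrative usage sketch (not part of the original module): how a caller
# might populate the state above. The request index, prompt length, and token
# IDs are made-up values, and `_example_populate` is a hypothetical helper
# that nothing in this module calls.
def _example_populate(state: LogitBiasState) -> None:
    params = SamplingParams(
        min_tokens=8,  # forbid stopping for the first 8 output tokens
        logit_bias={42: 2.5, 1337: -4.0},  # boost token 42, penalize 1337
    )
    state.add_request(req_idx=0, prompt_len=16, sampling_params=params)
    # Staged CPU writes only become visible to the kernel after this flush.
    state.apply_staged_writes()
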
@triton.jit
def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    # Logit bias.
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    # Min tokens.
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    # One program per batch row. Applies, in order: allowed-token masking,
    # additive logit bias, and min-tokens stop suppression, all in place.
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    block = tl.arange(0, BLOCK_SIZE)

    # Allowed token IDs.
    num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
    if num_allowed_token_ids > 0:
        mask = block < num_allowed_token_ids
        # Save logits for allowed token IDs.
        allowed_token_ids = tl.load(
            allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
        )
        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )
        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )

    # Logit bias.
    num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
    if num_logit_bias > 0:
        mask = block < num_logit_bias
        token_ids = tl.load(
            bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)

    # Min tokens: while the request is shorter than min_len, set the stop
    # tokens' logits to -inf so they cannot be sampled.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
        stop_token_ids = tl.load(
            stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )
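
# Illustrative eager-mode sketch (not part of the original module): a plain
# PyTorch equivalent of `_bias_kernel`, one request at a time, assuming the
# same tensor layouts as the Triton launcher below. Handy as a CPU test
# oracle for the kernel; `_bias_reference` is a hypothetical name.
def _bias_reference(
    logits: torch.Tensor,  # [num_reqs, vocab_size], modified in place
    idx_mapping: torch.Tensor,  # [num_reqs] batch index -> request state index
    pos: torch.Tensor,  # [num_reqs] current position of each request
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    for batch_idx in range(logits.shape[0]):
        i = int(idx_mapping[batch_idx])
        row = logits[batch_idx]
        # 1) Allowed token IDs: mask everything except the allowed set.
        n = int(num_allowed_token_ids[i])
        if n > 0:
            allowed = allowed_token_ids[i, :n].long()
            saved = row[allowed].clone()
            row.fill_(float("-inf"))
            row[allowed] = saved
        # 2) Logit bias: add per-token offsets.
        n = int(num_logit_bias[i])
        if n > 0:
            token_ids = logit_bias_token_ids[i, :n].long()
            row[token_ids] += logit_bias[i, :n]
        # 3) Min tokens: forbid stop tokens until the request reaches min_len.
        n = int(num_stop_token_ids[i])
        if n > 0 and int(pos[batch_idx]) < int(min_lens[i]):
            row[stop_token_ids[i, :n].long()] = float("-inf")
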
def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    # A single block must cover the widest per-request token list.
    BLOCK_SIZE = triton.next_power_of_2(
        max(
            allowed_token_ids.shape[-1],
            logit_bias_token_ids.shape[-1],
            stop_token_ids.shape[-1],
        )
    )
    LOGITS_BLOCK_SIZE = 8192
    _bias_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
    )
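
# Illustrative CPU demo (not part of the original module): exercises the
# eager `_bias_reference` sketch above on tiny made-up tensors to show the
# min-tokens rule. All shapes, IDs, and values here are hypothetical.
def _demo_min_tokens() -> None:
    logits = torch.zeros(1, 8)  # one request, vocab of 8 tokens
    zeros = torch.zeros(1, dtype=torch.int32)
    no_ids = torch.zeros((1, 4), dtype=torch.int32)
    _bias_reference(
        logits,
        idx_mapping=torch.tensor([0]),
        pos=torch.tensor([3]),  # position 3 is still before min_len=5
        num_allowed_token_ids=zeros,
        allowed_token_ids=no_ids,
        num_logit_bias=zeros,
        logit_bias_token_ids=no_ids,
        logit_bias=torch.zeros(1, 4),
        min_lens=torch.tensor([5], dtype=torch.int32),
        num_stop_token_ids=torch.tensor([1], dtype=torch.int32),
        stop_token_ids=torch.tensor([[7, 0, 0, 0]], dtype=torch.int32),
    )
    # The stop token (ID 7) is suppressed until the request reaches min_len.
    assert logits[0, 7] == float("-inf")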