# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState


class PenaltiesState:
    """Per-request-slot state for repetition, frequency and presence penalties.

    Holds the three penalty parameters for every request slot (indexed by the
    slot index used in ``RequestState``). It also holds two token statistics
    consumed by ``_penalties_kernel``:

    * ``prompt_bin_mask``: a bit-packed "token appears in the prompt" set.
    * ``output_bin_counts``: per-token occurrence counts of non-prompt tokens.

    Parameter writes go to host-side ``.np`` views first. They are published
    to the device copies (``.gpu``) by ``apply_staged_writes``.
    """

    def __init__(self, req_states: RequestState) -> None:
        """Allocate penalty parameter buffers and statistics for all slots.

        Args:
            req_states: Shared request state. Provides ``max_num_reqs``,
                ``vocab_size``, ``device``, and (later, in
                ``apply_staged_writes``) the token-id and length buffers.
        """
        self.req_states = req_states
        max_num_reqs = req_states.max_num_reqs
        self.vocab_size = req_states.vocab_size
        self.device = req_states.device

        # One float32 value per request slot. Each buffer exposes a host view
        # (`.np`) and a device view (`.gpu`). `copy_to_uva()` is called after
        # host writes, presumably to make them visible through `.gpu`.
        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        # Host-only flag per slot: True if any penalty is non-neutral (see
        # `use_penalty`). Used to skip the kernel launch entirely.
        self.use_penalty = np.zeros(max_num_reqs, dtype=bool)

        # Initialize repetition penalty manually because 0 is an invalid value
        # for it (1.0 is its no-op value). NOTE(review): this relies on
        # UvaBackedTensor starting zero-filled, which makes 0.0 the no-op
        # default for frequency/presence penalties -- confirm in buffer_utils.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        # Bit-packed prompt-token set: for token t, bit (t % 32) of int32 word
        # (t // 32) in the slot's row is set if t occurs in the prompt (see
        # `_bincount_kernel`).
        self.prompt_bin_mask = torch.zeros(
            max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # Dense per-token counts of tokens at positions [prompt_len, prefill_len)
        # of `all_token_ids` (see `_bincount_kernel`).
        # NOTE(review): within this file, only `bincount()` writes this tensor.
        # Any per-step update with newly sampled tokens must happen elsewhere
        # -- confirm.
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )

        # Slots added (via `add_request`) with a non-neutral penalty since the
        # last `apply_staged_writes`. Their statistics are rebuilt there.
        self._new_penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty parameters of a request into slot ``req_idx``.

        Only the host-side buffers are written here. The values reach the
        device, and the prompt/output statistics are rebuilt, on the next
        ``apply_staged_writes`` call.

        Args:
            req_idx: Request slot index.
            sampling_params: Source of the three penalty values.
        """
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty

        do_penalty = use_penalty(sampling_params)
        self.use_penalty[req_idx] = do_penalty
        if do_penalty:
            # Statistics are only needed (and only computed) for slots that
            # actually apply a penalty.
            self._new_penalties_reqs.append(req_idx)

    def apply_staged_writes(self) -> None:
        """Flush staged host writes to the device.

        If any newly added request uses penalties, its ``prompt_bin_mask`` and
        ``output_bin_counts`` rows are reset and recomputed from
        ``req_states.all_token_ids`` via ``bincount``. The three penalty
        parameter buffers are then copied to the device on every call.
        """
        if self._new_penalties_reqs:
            idx_mapping = async_tensor_h2d(
                self._new_penalties_reqs,
                dtype=torch.int32,
                target_device=self.device,
                pin_memory=True,
            )

            # The launch grid needs a host-side upper bound on token positions
            # to scan, so take the max prefill length on the host.
            prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
            max_prefill_len = int(prefill_lens.max())
            bincount(
                idx_mapping,
                self.req_states.all_token_ids.gpu,
                self.req_states.prompt_len.gpu,
                self.req_states.prefill_len.gpu,
                self.prompt_bin_mask,
                self.output_bin_counts,
                max_prefill_len,
            )
            self._new_penalties_reqs.clear()

        # Copied unconditionally: requests without penalties also overwrite
        # their slot's values (e.g. back to the no-op values), and the kernel
        # reads these device copies.
        self.repetition_penalty.copy_to_uva()
        self.frequency_penalty.copy_to_uva()
        self.presence_penalty.copy_to_uva()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
        num_speculative_tokens: int,
    ) -> None:
        """Apply penalties to ``logits`` in place for the rows that need it.

        Args:
            logits: 2-D ``(num_tokens, vocab_size)`` logits, modified in place.
            idx_mapping: Device tensor mapping each logits row to its request
                slot.
            idx_mapping_np: Host copy of ``idx_mapping``. It is only used for
                the skip check below (assumed to match ``idx_mapping`` --
                confirm at caller).
            input_ids: Token ids aligned with the logits rows. They are read
                to count draft tokens that precede each row.
            expanded_local_pos: Per-row position within its request's group
                of rows (0 for the first row of a request).
            num_speculative_tokens: Upper bound on ``expanded_local_pos``. It
                is the unroll length of the draft-count loop in the kernel.
        """
        if not np.any(self.use_penalty[idx_mapping_np]):
            # No request uses penalties. Skip the kernel launch.
            return

        # Resolves to the module-level launcher of the same name, not this
        # method.
        apply_penalties(
            logits,
            idx_mapping,
            input_ids,
            expanded_local_pos,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
            num_speculative_tokens,
        )


@triton.jit
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    token_ids_ptr,
    expanded_local_pos_ptr,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    MAX_SPEC_LEN: tl.constexpr,
):
    # Grid: (num logits rows, cdiv(vocab_size, BLOCK_SIZE)). Each program
    # penalizes one BLOCK_SIZE-wide vocab chunk of one logits row, in place.
    # BLOCK_SIZE must be a multiple of 32 so that a chunk maps onto whole
    # int32 words of the bit-packed prompt mask.
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
    rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
    pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)

    # Same no-op values as the host-side `use_penalty()` helper.
    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
    if not use_penalty:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    # Masked-off lanes are left undefined here; they are never stored back.
    logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    base_output_counts = tl.load(
        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
        mask=mask,
        other=0,
    )

    # Compute cumulative draft_counts from previous positions in this request.
    # Row `token_idx` sits at local position `pos` within its request's group
    # of rows, so the group starts at row `start_idx` (this assumes a
    # request's rows are contiguous). Tokens input_ids[start_idx + 1 ..
    # start_idx + pos] are counted as extra output occurrences for this row --
    # presumably the draft tokens that would precede it if accepted.
    pos = tl.load(expanded_local_pos_ptr + token_idx)
    start_idx = token_idx - pos
    draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    # Unrolled at compile time (MAX_SPEC_LEN is constexpr); the runtime guard
    # limits it to the first `pos` positions.
    for prev_pos in tl.static_range(MAX_SPEC_LEN):
        if prev_pos < pos:
            prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
            token_match = block == prev_token
            draft_counts = draft_counts + token_match.to(tl.int32)

    # Total counts = base output counts + cumulative draft counts
    output_bin_counts = base_output_counts + draft_counts

    output_bin_mask = output_bin_counts > 0

    # Apply repetition penalties.
    if use_rep_penalty:
        # Load the BLOCK_SIZE // 32 packed int32 words covering this chunk.
        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
        packed_mask = tl.load(
            prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
            mask=packed_block < tl.cdiv(vocab_size, 32),
            other=0,
        )
        # Unpack: element [w, j] is bit j of word w (token 32*w + j). A
        # row-major reshape to BLOCK_SIZE puts the bits back in token order.
        prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
        prompt_bin_mask = prompt_bin_mask.to(tl.int1)
        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
        scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
        # If logits are positive, divide by penalty, otherwise multiply by penalty.
        logits *= tl.where(logits > 0, 1.0 / scale, scale)

    # Apply frequency penalties (proportional to occurrence count).
    logits -= freq_penalty * output_bin_counts
    # Apply presence penalties (flat, once a token has occurred at all).
    logits -= pres_penalty * output_bin_mask
    # Store back to logits.
    tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_penalties(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    token_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    repetition_penalty: torch.Tensor,
    frequency_penalty: torch.Tensor,
    presence_penalty: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    num_speculative_tokens: int,
) -> None:
    """Launch ``_penalties_kernel`` to penalize ``logits`` in place.

    The grid is ``(num_tokens, cdiv(vocab_size, 8192))``: one program per
    logits row and 8192-wide vocab chunk. ``num_speculative_tokens`` is passed
    as the constexpr ``MAX_SPEC_LEN``, so each distinct value compiles its own
    kernel specialization.

    Args:
        logits: 2-D ``(num_tokens, vocab_size)`` tensor, modified in place.
        idx_mapping: Per-row request slot index.
        token_ids: Token ids aligned with the logits rows (for draft counts).
        expanded_local_pos: Per-row position within its request's rows.
        repetition_penalty: Per-slot repetition penalties (device).
        frequency_penalty: Per-slot frequency penalties (device).
        presence_penalty: Per-slot presence penalties (device).
        prompt_bin_mask: Per-slot bit-packed prompt-token set.
        output_bin_counts: Per-slot per-token output counts.
        num_speculative_tokens: Upper bound on ``expanded_local_pos``.
    """
    num_tokens, vocab_size = logits.shape
    # Must stay a multiple of 32 (see `_penalties_kernel`).
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        token_ids,
        expanded_local_pos,
        repetition_penalty,
        frequency_penalty,
        presence_penalty,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_SPEC_LEN=num_speculative_tokens,
    )


@triton.jit
def _bincount_kernel(
    idx_mapping_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
    prefill_len_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid: (num requests, cdiv(max_prefill_len, BLOCK_SIZE)). Each program
    # scans one BLOCK_SIZE-wide chunk of token positions of one request:
    #   positions [0, prompt_len)           -> set bits in prompt_bin_mask
    #   positions [prompt_len, prefill_len) -> increment output_bin_counts
    # Atomics are used because several lanes/programs can hit the same
    # packed word or the same token count.
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    # The grid is sized for the longest request; shorter ones exit early.
    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if block_idx * BLOCK_SIZE >= prefill_len:
        return

    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Chunk overlaps the prompt: mark each prompt token's bit.
    if block_idx * BLOCK_SIZE < prompt_len:
        mask = block < prompt_len
        prompt_tokens = tl.load(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
        )
        idx = prompt_tokens // 32
        bit_idx = prompt_tokens % 32
        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
        tl.atomic_or(
            prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
            bit,
            mask=mask,
        )
    # Chunk may extend past the prompt: count tokens in [prompt_len, prefill_len).
    # This is deliberately not nested under the prompt branch, so chunks that
    # lie entirely past the prompt are still counted.
    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
        mask = block < prefill_len
        mask &= block >= prompt_len
        output_tokens = tl.load(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
        )
        tl.atomic_add(
            output_bin_counts_ptr
            + req_state_idx * output_bin_counts_stride
            + output_tokens,
            1,
            mask=mask,
        )


def bincount(
    idx_mapping: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    prefill_len: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    max_prefill_len: int,
) -> None:
    """Rebuild prompt/output token statistics for the given request slots.

    The rows of ``prompt_bin_mask`` and ``output_bin_counts`` selected by
    ``idx_mapping`` are zeroed, then repopulated from ``all_token_ids`` by
    ``_bincount_kernel``.

    Args:
        idx_mapping: Device tensor of request slot indices to rebuild.
        all_token_ids: Per-slot token ids (rows indexed by slot).
        prompt_len: Per-slot prompt length.
        prefill_len: Per-slot number of valid tokens in ``all_token_ids``.
        prompt_bin_mask: Output, bit-packed prompt-token set per slot.
        output_bin_counts: Output, per-token counts of non-prompt tokens.
        max_prefill_len: Host-side max of ``prefill_len`` over
            ``idx_mapping``. It sizes the second grid dimension.
    """
    # Reset first: the kernel only ORs/adds into these rows.
    prompt_bin_mask[idx_mapping] = 0
    output_bin_counts[idx_mapping] = 0
    num_reqs = idx_mapping.shape[0]
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
    _bincount_kernel[(num_reqs, num_blocks)](
        idx_mapping,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,
        prefill_len,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )


def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return True if any penalty differs from its no-op value.

    The no-op values are 1.0 for ``repetition_penalty`` and 0.0 for
    ``frequency_penalty`` and ``presence_penalty``. These match the checks
    in ``_penalties_kernel``.
    """
    return (
        sampling_params.repetition_penalty != 1.0
        or sampling_params.frequency_penalty != 0.0
        or sampling_params.presence_penalty != 0.0
    )