150 lines
4.3 KiB
Python
150 lines
4.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
|
|
@triton.jit
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Divide one vocabulary block of one logits row by its temperature, in place.

    Program axis 0 selects the logits row; program axis 1 selects a
    BLOCK_SIZE-wide slice of that row. The temperature for a row is read from
    ``temperature_ptr[idx_mapping_ptr[row]]``.

    Rows whose temperature is exactly 0.0 or 1.0 are left untouched: dividing
    by 1.0 changes nothing, and 0.0 cannot be used as a divisor.
    """
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size

    # tl.program_id() is 32-bit, so `batch_idx * logits_stride` would wrap
    # around once the row offset reaches 2**31 elements. Widen the row index
    # to 64-bit before forming the offset.
    row_ptr = logits_ptr + batch_idx.to(tl.int64) * logits_stride

    logits = tl.load(row_ptr + block, mask=mask)
    # Do the division in FP32 regardless of the storage dtype.
    logits = logits.to(tl.float32)
    logits = logits / temperature
    tl.store(row_ptr + block, logits, mask=mask)
|
|
|
|
|
|
def apply_temperature(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    """Scale ``logits`` in place by per-row temperatures.

    Launches ``_temperature_kernel`` on a 2D grid with one program per
    (row of ``logits``, 8192-wide vocabulary block).

    Args:
        logits: 2D tensor; modified in place.
        idx_mapping: For each row of ``logits``, the index into
            ``temperature`` to use for that row.
        temperature: Temperature values, looked up through ``idx_mapping``.
    """
    block_size = 8192
    num_rows, num_cols = logits.shape
    # One program per (row, vocabulary block).
    grid = (num_rows, triton.cdiv(num_cols, block_size))
    _temperature_kernel[grid](
        logits,
        logits.stride(0),
        idx_mapping,
        temperature,
        num_cols,
        BLOCK_SIZE=block_size,
    )
|
|
|
|
|
|
@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    """Compute the per-block maximum of (optionally perturbed) logits.

    Program axis 0 selects the logits row; program axis 1 selects a
    BLOCK_SIZE-wide slice of the vocabulary. For its slice, each program
    writes the winning token id to ``local_argmax_ptr[row, block]`` and the
    winning score to ``local_max_ptr[row, block]``.

    When the row's temperature (``temp_ptr[idx_mapping_ptr[row]]``) is
    non-zero, Gumbel noise is added to the logits before taking the maximum;
    if APPLY_TEMPERATURE is set the logits are first divided by the
    temperature. When the temperature is 0.0 the raw logits are used, so the
    result is a plain argmax.
    """
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    # tl.program_id() is 32-bit, so `batch_idx * logits_stride` would wrap
    # around once the row offset reaches 2**31 elements. Use a 64-bit row
    # index for every row-offset computation below.
    row = batch_idx.to(tl.int64)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    # Out-of-range lanes read as -inf so they can never win the max.
    logits = tl.load(
        logits_ptr + row * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + batch_idx)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise in FP32.
        u = tl.rand(gumbel_seed, block)
        # Clamp away from zero so that log(u) stays finite.
        u = tl.maximum(u, 1e-7)
        gumbel_noise = -tl.log(-tl.log(u))

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Match the behavior of _temperature_kernel.
            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
            logits = logits / temp

        # Apply gumbel noise.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    value, idx = tl.max(logits, axis=0, return_indices=True)
    # Convert the in-block index into a vocabulary-wide token id.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + row * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + row * local_max_stride + block_idx, value)
|
|
|
|
|
|
def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    idx_mapping: torch.Tensor,  # [max_num_reqs]
    temperature: torch.Tensor,  # [max_num_reqs]
    seed: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Pick one token id per row of ``logits`` in two reduction stages.

    Stage 1 launches ``_gumbel_sample_kernel`` with one program per
    (row, 1024-wide vocabulary block); each program records the best score
    and the matching token id for its block. Stage 2 selects, per row, the
    block with the highest score and returns that block's token id.

    Returns:
        A 1D int64 tensor with one token id per row of ``logits``.
    """
    block_size = 1024
    num_rows, num_cols = logits.shape
    blocks_per_row = triton.cdiv(num_cols, block_size)
    # The scratch buffers and the launch grid share the same 2D extent.
    partial_shape = (num_rows, blocks_per_row)

    # NOTE(woosuk): Use int64 for later indexing.
    block_argmax = torch.empty(
        partial_shape,
        dtype=torch.int64,
        device=logits.device,
    )
    block_max = torch.empty(
        partial_shape,
        dtype=torch.float32,
        device=logits.device,
    )

    _gumbel_sample_kernel[partial_shape](
        block_argmax,
        block_argmax.stride(0),
        block_max,
        block_max.stride(0),
        logits,
        logits.stride(0),
        idx_mapping,
        seed,
        pos,
        temperature,
        num_cols,
        BLOCK_SIZE=block_size,
        APPLY_TEMPERATURE=apply_temperature,
    )

    # Reduce across blocks: take the token id from the best-scoring block.
    winning_block = block_max.argmax(dim=-1, keepdim=True)
    return block_argmax.gather(dim=-1, index=winning_block).view(-1)
|