update
vllm/v1/worker/gpu/sample/__init__.py (new file, 0 lines)

vllm/v1/worker/gpu/sample/bad_words.py (new file, 194 lines)
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState

MAX_BAD_WORDS_TOTAL_TOKENS = 1024  # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128  # Max number of bad words per request


class BadWordsState:
    def __init__(self, req_states: RequestState):
        self.req_states = req_states
        self.max_num_reqs = req_states.max_num_reqs
        self.device = req_states.device

        # flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
        self.bad_word_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
            dtype=torch.int32,
            device=self.device,
        )
        # cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
        self.bad_word_offsets = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
            dtype=torch.int32,
            device=self.device,
        )
        # number of bad words per request
        self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        bad_words_token_ids = sampling_params.bad_words_token_ids
        if not bad_words_token_ids:
            self.num_bad_words.np[req_idx] = 0
            return

        num_bad_words = len(bad_words_token_ids)
        if num_bad_words > MAX_NUM_BAD_WORDS:
            raise ValueError(
                f"Too many bad words: {num_bad_words}. "
                f"The max number is {MAX_NUM_BAD_WORDS}."
            )

        # Flatten bad words and compute offsets
        flattened_tokens: list[int] = []
        offsets: list[int] = [0]
        for bad_word in bad_words_token_ids:
            flattened_tokens.extend(bad_word)
            offsets.append(len(flattened_tokens))

        if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
            raise ValueError(
                f"Too many total bad word tokens: {len(flattened_tokens)}. "
                f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
            )

        # Stage writes
        self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
        self.bad_word_offsets.stage_write(req_idx, 0, offsets)
        self.num_bad_words.np[req_idx] = num_bad_words

    def apply_staged_writes(self) -> None:
        self.num_bad_words.copy_to_uva()
        self.bad_word_token_ids.apply_write()
        self.bad_word_offsets.apply_write()

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> None:
        max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
        if max_num_bad_words == 0:
            # No request uses bad words. Skip the kernel launch.
            return

        apply_bad_words(
            logits,
            idx_mapping,
            self.bad_word_token_ids.gpu,
            self.bad_word_offsets.gpu,
            self.num_bad_words.gpu,
            self.req_states.all_token_ids.gpu,
            self.req_states.prompt_len.gpu,
            self.req_states.total_len.gpu,
            input_ids,
            expanded_local_pos,
            max_num_bad_words,
        )


@triton.jit
def _bad_words_kernel(
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    bad_word_token_ids_ptr,
    bad_word_token_ids_stride,
    bad_word_offsets_ptr,
    bad_word_offsets_stride,
    num_bad_words_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
    total_len_ptr,
    input_ids_ptr,
    expanded_local_pos_ptr,
):
    logit_idx = tl.program_id(0)
    bw_idx = tl.program_id(1)

    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)

    if bw_idx >= num_bad_words:
        return

    pos = tl.load(expanded_local_pos_ptr + logit_idx)
    cur_req_first_pos = logit_idx - pos

    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    total_len = tl.load(total_len_ptr + req_state_idx)
    output_len = total_len - prompt_len
    effective_len = output_len + pos

    bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
    bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
    output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len

    start = tl.load(bd_offsets_base + bw_idx)
    end = tl.load(bd_offsets_base + bw_idx + 1)
    bad_word_len = end - start
    prefix_len = bad_word_len - 1

    if prefix_len > effective_len:
        return

    last_token = tl.load(bd_tokens_base + end - 1)
    match = 1
    for i in range(prefix_len):
        expected = tl.load(bd_tokens_base + start + i)
        actual_pos = effective_len - prefix_len + i

        from_spec_input = actual_pos >= output_len
        if from_spec_input:
            spec_offset = actual_pos - output_len
            actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
        else:
            actual = tl.load(output_base + actual_pos)

        match = match & (expected == actual)

    if match:
        tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))


def apply_bad_words(
    logits: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    bad_word_token_ids: torch.Tensor,
    bad_word_offsets: torch.Tensor,
    num_bad_words: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    total_len: torch.Tensor,
    input_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    max_num_bad_words: int,
) -> None:
    total_num_tokens = logits.shape[0]
    _bad_words_kernel[(total_num_tokens, max_num_bad_words)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        bad_word_token_ids,
        bad_word_token_ids.stride(0),
        bad_word_offsets,
        bad_word_offsets.stride(0),
        num_bad_words,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,
        total_len,
        input_ids,
        expanded_local_pos,
    )
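The staged buffers above hold each request's bad words in a CSR-style layout: all token ids flattened into one row, with cumulative offsets delimiting the individual words. A minimal NumPy sketch of that layout and of the prefix check the kernel performs (standalone illustration, not part of the diff; the token ids and history are made up):

# Illustrative only: CSR-style layout of bad words for a single request.
import numpy as np

bad_words = [[5, 6], [7], [8, 9, 10]]        # hypothetical token-id sequences to ban
flat: list[int] = []
offsets = [0]
for word in bad_words:
    flat.extend(word)
    offsets.append(len(flat))
tokens = np.array(flat, dtype=np.int32)       # -> [5, 6, 7, 8, 9, 10]
offsets = np.array(offsets, dtype=np.int32)   # -> [0, 2, 3, 6]

# A bad word of length L bans its last token only if the L - 1 most recently
# generated tokens equal its first L - 1 tokens (the kernel's prefix match).
history = [3, 5]                              # hypothetical generated tokens
w = 0                                         # check bad word [5, 6]
start, end = int(offsets[w]), int(offsets[w + 1])
prefix = tokens[start : end - 1].tolist()
if len(prefix) <= len(history) and history[len(history) - len(prefix):] == prefix:
    banned = int(tokens[end - 1])             # 6: its logit would be set to -inf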
vllm/v1/worker/gpu/sample/gumbel.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size

    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)
    logits = logits / temperature
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)


def apply_temperature(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _temperature_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )


@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + batch_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + batch_idx)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise in FP32.
        u = tl.rand(gumbel_seed, block)
        u = tl.maximum(u, 1e-7)
        gumbel_noise = -tl.log(-tl.log(u))

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Match the behavior of _temperature_kernel.
            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
            logits = logits / temp

        # Apply gumbel noise.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    value, idx = tl.max(logits, axis=0, return_indices=True)
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)


def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    idx_mapping: torch.Tensor,  # [max_num_reqs]
    temperature: torch.Tensor,  # [max_num_reqs]
    seed: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    local_argmax = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.int64,
        device=logits.device,
    )
    local_max = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.float32,
        device=logits.device,
    )
    _gumbel_sample_kernel[(num_reqs, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        logits,
        logits.stride(0),
        idx_mapping,
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
    return sampled
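_gumbel_sample_kernel relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the temperature-scaled logits and taking the argmax draws from the same distribution as softmax sampling, so no explicit softmax or cumulative sum is needed. A dense PyTorch sketch of the same idea, for reference only (it ignores the per-request seeding and the blockwise argmax reduction used above):

# Illustrative only: dense Gumbel-max sampling equivalent.
import torch

def gumbel_argmax_reference(logits: torch.Tensor) -> torch.Tensor:
    # logits: [num_reqs, vocab_size], already temperature-scaled.
    # argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)).
    u = torch.rand_like(logits.float()).clamp_min(1e-7)  # avoid log(0), as in the kernel
    gumbel_noise = -torch.log(-torch.log(u))
    return torch.argmax(logits.float() + gumbel_noise, dim=-1)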
vllm/v1/worker/gpu/sample/logit_bias.py (new file, 280 lines)
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

MAX_NUM_ALLOWED_TOKEN_IDS = 1024
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
MAX_NUM_STOP_TOKEN_IDS = 128


class LogitBiasState:
    def __init__(self, max_num_reqs: int, device: torch.device):
        self.max_num_reqs = max_num_reqs

        # Allowed token IDs.
        self.num_allowed_token_ids = UvaBackedTensor(
            self.max_num_reqs, dtype=torch.int32
        )
        self.allowed_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )
        # Logit bias.
        self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.logit_bias_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.int32,
            device=device,
        )
        self.logit_bias = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
            dtype=torch.float32,
            device=device,
        )
        # Min tokens.
        self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.stop_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

        # Whether each request uses any of the above.
        self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        # Whether this request uses any form of logit bias.
        use_logit_bias = False

        # Allowed token IDs.
        allowed_token_ids = sampling_params.allowed_token_ids
        if allowed_token_ids:
            num_allowed_token_ids = len(allowed_token_ids)
            if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
                raise ValueError(
                    f"Too many allowed token IDs: {num_allowed_token_ids}. "
                    f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
                )
            self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
            self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
            use_logit_bias = True
        else:
            self.num_allowed_token_ids.np[req_idx] = 0

        # Logit bias.
        logit_bias = sampling_params.logit_bias
        if logit_bias:
            num_logit_bias = len(logit_bias)
            if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
                raise ValueError(
                    f"Too many logit bias tokens: {num_logit_bias}. "
                    f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
                )
            self.num_logit_bias.np[req_idx] = num_logit_bias
            self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
            self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
            use_logit_bias = True
        else:
            self.num_logit_bias.np[req_idx] = 0

        # Min tokens.
        min_tokens = sampling_params.min_tokens
        min_len = prompt_len + min_tokens
        self.min_lens.np[req_idx] = min_len
        stop_token_ids = sampling_params.all_stop_token_ids
        if min_tokens > 0 and stop_token_ids:
            num_stop_token_ids = len(stop_token_ids)
            if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
                raise ValueError(
                    f"Too many stop tokens: {num_stop_token_ids}. "
                    f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
                )
            self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
            self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
            use_logit_bias = True
        else:
            self.num_stop_token_ids.np[req_idx] = 0

        self.use_logit_bias[req_idx] = use_logit_bias

    def apply_staged_writes(self) -> None:
        self.num_allowed_token_ids.copy_to_uva()
        self.allowed_token_ids.apply_write()

        self.num_logit_bias.copy_to_uva()
        self.logit_bias_token_ids.apply_write()
        self.logit_bias.apply_write()

        self.min_lens.copy_to_uva()
        self.num_stop_token_ids.copy_to_uva()
        self.stop_token_ids.apply_write()

    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> None:
        if not np.any(self.use_logit_bias[idx_mapping_np]):
            # No request uses logit bias. Skip the kernel launch.
            return

        apply_logit_bias(
            logits,
            idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
            self.num_logit_bias.gpu,
            self.logit_bias_token_ids.gpu,
            self.logit_bias.gpu,
            self.min_lens.gpu,
            self.num_stop_token_ids.gpu,
            self.stop_token_ids.gpu,
        )


@triton.jit
def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
    allowed_token_ids_stride,
    # Logit bias.
    num_logit_bias_ptr,
    bias_token_ids_ptr,
    bias_token_ids_stride,
    bias_ptr,
    bias_stride,
    # Min tokens.
    pos_ptr,
    min_lens_ptr,
    num_stop_token_ids_ptr,
    stop_token_ids_ptr,
    stop_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    block = tl.arange(0, BLOCK_SIZE)

    # Allowed token IDs.
    num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
    if num_allowed_token_ids > 0:
        block = tl.arange(0, BLOCK_SIZE)
        mask = block < num_allowed_token_ids

        # Save logits for allowed token IDs.
        allowed_token_ids = tl.load(
            allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
        )

        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )

        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )

    # Logit bias.
    num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
    if num_logit_bias > 0:
        mask = block < num_logit_bias
        token_ids = tl.load(
            bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)

    # Apply min tokens.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
        stop_token_ids = tl.load(
            stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )


def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = triton.next_power_of_2(
        max(
            allowed_token_ids.shape[-1],
            logit_bias_token_ids.shape[-1],
            stop_token_ids.shape[-1],
        )
    )
    LOGITS_BLOCK_SIZE = 8192
    _bias_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
    )
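Instead of building a boolean mask over the whole vocabulary, _bias_kernel handles allowed_token_ids by saving the logits at the allowed positions, overwriting the entire row with -inf, and writing the saved values back. A dense single-row PyTorch sketch of that save/fill/restore pattern (illustrative only, not part of the diff):

# Illustrative only: dense equivalent of the allowed-token-ids masking.
import torch

def allowed_token_ids_reference(row_logits: torch.Tensor, allowed: list[int]) -> torch.Tensor:
    # row_logits: [vocab_size] logits for one request.
    idx = torch.tensor(allowed, dtype=torch.long, device=row_logits.device)
    saved = row_logits[idx].clone()                    # save allowed logits
    out = torch.full_like(row_logits, float("-inf"))   # fill everything with -inf
    out[idx] = saved                                   # restore the allowed entries
    return out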
vllm/v1/worker/gpu/sample/logprob.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors


@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        logits = logits.to(tl.float32)
        e = tl.exp(logits - max_val)
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)

    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


@triton.jit
def _ranks_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    token_ids_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    token_id = tl.load(token_ids_ptr + req_idx)
    x = tl.load(row_ptr + token_id)

    n = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        n += tl.sum((logits >= x).to(tl.int32))
    tl.store(output_ptr + req_idx, n)


def compute_token_logprobs(
    logits: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    token_ids = token_ids.to(torch.int64)
    num_logprobs = token_ids.shape[1]
    logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        PADDED_TOPK=triton.next_power_of_2(num_logprobs),
    )
    return logprobs


def compute_topk_logprobs(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
    cu_num_logits: list[int] | None = None,
) -> LogprobsTensors:
    assert num_logprobs >= 0
    batch_size, vocab_size = logits.shape
    logprob_token_ids = sampled_token_ids.unsqueeze(-1)
    if num_logprobs > 0:
        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
        logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)

    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
    # logprobs tensor. Instead, we only compute and return the logprobs of
    # the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)
    token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
    _ranks_kernel[(batch_size,)](
        token_ranks,
        logits,
        logits.stride(0),
        sampled_token_ids,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )
    return LogprobsTensors(
        logprob_token_ids=logprob_token_ids,
        logprobs=logprobs,
        selected_token_ranks=token_ranks,
        cu_num_generated_tokens=cu_num_logits,
    )
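_topk_log_softmax_kernel streams over the vocabulary twice, first to find the row maximum and then to accumulate sum(exp(logit - max)), so full log-softmax values are only ever produced at the requested token positions. A dense PyTorch equivalent of that computation, for reference (it materializes the whole row, which is exactly what the kernel avoids):

# Illustrative only: dense equivalent of the top-k log-softmax gather.
import torch

def topk_logprobs_reference(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, vocab], token_ids: [batch, k] ids to score.
    logits = logits.float()
    max_val = logits.max(dim=-1, keepdim=True).values
    lse = torch.log(torch.exp(logits - max_val).sum(dim=-1, keepdim=True))
    # log_softmax(logits)[token] = logit[token] - max - log(sum(exp(logit - max)))
    return logits.gather(-1, token_ids.long()) - max_val - lse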
vllm/v1/worker/gpu/sample/min_p.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _min_p_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    min_p_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
    if min_p == 0.0:
        return

    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
        )
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    threshold = max_val + tl.log(min_p)
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
        )
        logits = tl.where(logits < threshold, float("-inf"), logits)
        tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)


def apply_min_p(
    logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    _min_p_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        idx_mapping,
        min_p,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
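The threshold max_val + tl.log(min_p) works because min-p keeps tokens whose softmax probability is at least min_p times the top token's probability, and both probabilities share the same normalizer, so the test reduces to logit >= max_logit + log(min_p). A dense PyTorch sketch of the same filter (reference only, not part of the diff):

# Illustrative only: dense min-p filtering in logit space.
import math

import torch

def min_p_reference(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    # prob(token) >= min_p * prob(argmax)  <=>  logit >= max_logit + log(min_p).
    threshold = logits.max(dim=-1, keepdim=True).values + math.log(min_p)
    return logits.masked_fill(logits < threshold, float("-inf"))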
vllm/v1/worker/gpu/sample/output.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

from vllm.v1.outputs import LogprobsTensors


@dataclass
class SamplerOutput:
    sampled_token_ids: torch.Tensor
    logprobs_tensors: LogprobsTensors | None
    num_nans: torch.Tensor | None
vllm/v1/worker/gpu/sample/penalties.py (new file, 311 lines)
@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState


class PenaltiesState:
    def __init__(self, req_states: RequestState):
        self.req_states = req_states

        max_num_reqs = req_states.max_num_reqs
        self.vocab_size = req_states.vocab_size
        self.device = req_states.device

        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.use_penalty = np.zeros(max_num_reqs, dtype=bool)

        # Initialize repetition penalty manually because 0 is an invalid value for it.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        self.prompt_bin_mask = torch.zeros(
            max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )

        self._new_penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty

        do_penalty = use_penalty(sampling_params)
        self.use_penalty[req_idx] = do_penalty
        if do_penalty:
            self._new_penalties_reqs.append(req_idx)

    def apply_staged_writes(self) -> None:
        if self._new_penalties_reqs:
            idx_mapping = async_tensor_h2d(
                self._new_penalties_reqs,
                dtype=torch.int32,
                target_device=self.device,
                pin_memory=True,
            )

            prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
            max_prefill_len = int(prefill_lens.max())
            bincount(
                idx_mapping,
                self.req_states.all_token_ids.gpu,
                self.req_states.prompt_len.gpu,
                self.req_states.prefill_len.gpu,
                self.prompt_bin_mask,
                self.output_bin_counts,
                max_prefill_len,
            )
            self._new_penalties_reqs.clear()

        self.repetition_penalty.copy_to_uva()
        self.frequency_penalty.copy_to_uva()
        self.presence_penalty.copy_to_uva()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
        num_speculative_tokens: int,
    ) -> None:
        if not np.any(self.use_penalty[idx_mapping_np]):
            # No request uses penalties. Skip the kernel launch.
            return

        apply_penalties(
            logits,
            idx_mapping,
            input_ids,
            expanded_local_pos,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
            num_speculative_tokens,
        )


@triton.jit
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    token_ids_ptr,
    expanded_local_pos_ptr,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    MAX_SPEC_LEN: tl.constexpr,
):
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
    rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
    pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
    if not use_penalty:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    base_output_counts = tl.load(
        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
        mask=mask,
        other=0,
    )

    # Compute cumulative draft_counts from previous positions in this request
    pos = tl.load(expanded_local_pos_ptr + token_idx)
    start_idx = token_idx - pos
    draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for prev_pos in tl.static_range(MAX_SPEC_LEN):
        if prev_pos < pos:
            prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
            token_match = block == prev_token
            draft_counts = draft_counts + token_match.to(tl.int32)

    # Total counts = base output counts + cumulative draft counts
    output_bin_counts = base_output_counts + draft_counts
    output_bin_mask = output_bin_counts > 0

    # Apply repetition penalties.
    if use_rep_penalty:
        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
        packed_mask = tl.load(
            prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
            mask=packed_block < tl.cdiv(vocab_size, 32),
            other=0,
        )
        prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
        prompt_bin_mask = prompt_bin_mask.to(tl.int1)
        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
        scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
        # If logits are positive, divide by penalty, otherwise multiply by penalty.
        logits *= tl.where(logits > 0, 1.0 / scale, scale)

    # Apply frequency penalties.
    logits -= freq_penalty * output_bin_counts
    # Apply presence penalties.
    logits -= pres_penalty * output_bin_mask
    # Store back to logits.
    tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_penalties(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    token_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    repetition_penalty: torch.Tensor,
    frequency_penalty: torch.Tensor,
    presence_penalty: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    num_speculative_tokens: int,
) -> None:
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        token_ids,
        expanded_local_pos,
        repetition_penalty,
        frequency_penalty,
        presence_penalty,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_SPEC_LEN=num_speculative_tokens,
    )


@triton.jit
def _bincount_kernel(
    idx_mapping_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
    prefill_len_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if block_idx * BLOCK_SIZE >= prefill_len:
        return

    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if block_idx * BLOCK_SIZE < prompt_len:
        mask = block < prompt_len
        prompt_tokens = tl.load(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
        )
        idx = prompt_tokens // 32
        bit_idx = prompt_tokens % 32
        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
        tl.atomic_or(
            prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
            bit,
            mask=mask,
        )

    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
        mask = block < prefill_len
        mask &= block >= prompt_len
        output_tokens = tl.load(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
        )
        tl.atomic_add(
            output_bin_counts_ptr
            + req_state_idx * output_bin_counts_stride
            + output_tokens,
            1,
            mask=mask,
        )


def bincount(
    idx_mapping: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    prefill_len: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    max_prefill_len: int,
) -> None:
    prompt_bin_mask[idx_mapping] = 0
    output_bin_counts[idx_mapping] = 0
    num_reqs = idx_mapping.shape[0]
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
    _bincount_kernel[(num_reqs, num_blocks)](
        idx_mapping,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,
        prefill_len,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )


def use_penalty(sampling_params: SamplingParams) -> bool:
    return (
        sampling_params.repetition_penalty != 1.0
        or sampling_params.frequency_penalty != 0.0
        or sampling_params.presence_penalty != 0.0
    )
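prompt_bin_mask packs prompt-token membership into one bit per vocabulary id, 32 ids per int32 word, which is why _bincount_kernel writes 1 << (token % 32) into word token // 32 and _penalties_kernel unpacks 32 bits from each loaded word. A small NumPy sketch of the packing and the membership test (illustrative only; it uses uint32 to sidestep signed-shift warnings, whereas the kernels use int32 with atomics):

# Illustrative only: bit-packed prompt membership mask.
import numpy as np

def pack_prompt_mask(prompt_token_ids: list[int], vocab_size: int) -> np.ndarray:
    # One bit per vocab id, 32 ids per word, mirroring prompt_bin_mask's layout.
    mask = np.zeros((vocab_size + 31) // 32, dtype=np.uint32)
    for t in prompt_token_ids:
        mask[t // 32] |= np.uint32(1) << np.uint32(t % 32)
    return mask

def in_prompt(mask: np.ndarray, token_id: int) -> bool:
    return bool((mask[token_id // 32] >> np.uint32(token_id % 32)) & np.uint32(1))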
vllm/v1/worker/gpu/sample/prompt_logprob.py (new file, 208 lines)
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs


class PromptLogprobsWorker:
    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs

        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        # req_id -> list of in-progress LogprobsTensors
        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}

    def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
        # For now, only support prompt logprobs for the prompt tokens (not top-k).
        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
        self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
        if uses_prompt_logprobs:
            self.in_progress_prompt_logprobs[req_id] = []

    def remove_request(self, req_id: str) -> None:
        self.in_progress_prompt_logprobs.pop(req_id, None)

    def compute_prompt_logprobs(
        self,
        logits_fn: Callable[[torch.Tensor], torch.Tensor],
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        # [max_num_reqs, max_model_len]
        all_token_ids: torch.Tensor,
        # [max_num_reqs]
        num_computed_tokens: torch.Tensor,
        # [max_num_reqs]
        prompt_lens: np.ndarray,
        # [max_num_reqs]
        prefill_lens: np.ndarray,
        # [max_num_reqs]
        num_computed_prefill_tokens: np.ndarray,
    ) -> dict[str, LogprobsTensors]:
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # Common case: No request asks for prompt logprobs.
            return {}

        prompt_lens = prompt_lens[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
        includes_prompt = computed_prefill < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Get the prompt logprobs token_ids.
        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
            input_batch.num_tokens,
            input_batch.query_start_loc,
            input_batch.idx_mapping,
            num_computed_tokens,
            all_token_ids,
        )
        # Compute the prompt logprobs.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
            prompt_logprobs_token_ids,
            hidden_states[: input_batch.num_tokens],
            logits_fn,
        )

        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < prompt_lens

        query_start_loc_np = input_batch.query_start_loc_np
        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            start_idx = query_start_loc_np[i]
            end_idx = query_start_loc_np[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            if not is_prompt_chunked[i]:
                end_idx -= 1
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict


@triton.jit
def _prompt_logprobs_token_ids_kernel(
    prompt_logprobs_token_ids_ptr,
    query_start_loc_ptr,
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        # NOTE(woosuk): We should shift the pos by one
        # because the logprob is computed for the next token.
        target_pos = num_computed_tokens + 1 + block
        token_ids = tl.load(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
            mask=mask,
        )
        tl.store(
            prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
        )


def get_prompt_logprobs_token_ids(
    num_tokens: int,
    query_start_loc: torch.Tensor,
    idx_mapping: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    all_token_ids: torch.Tensor,
) -> torch.Tensor:
    token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
    num_reqs = idx_mapping.shape[0]
    _prompt_logprobs_token_ids_kernel[(num_reqs,)](
        token_ids,
        query_start_loc,
        idx_mapping,
        num_computed_tokens,
        all_token_ids,
        all_token_ids.stride(0),
        BLOCK_SIZE=1024,
    )
    return token_ids


def compute_prompt_logprobs_with_chunking(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    # Since materializing the full prompt logits can take too much memory,
    # we compute it in chunks.
    CHUNK_SIZE = 1024
    logprobs = []
    ranks = []
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
        end_idx = start_idx + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        prompt_logprobs = compute_topk_logprobs(
            prompt_logits,
            0,  # num_logprobs
            prompt_token_ids[start_idx:end_idx],
        )
        logprobs.append(prompt_logprobs.logprobs)
        ranks.append(prompt_logprobs.selected_token_ranks)

    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
    return logprobs, ranks
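The num_computed_tokens + 1 + block indexing above encodes the usual off-by-one alignment: the logits produced from the hidden state at prompt position p score the token at position p + 1, so the final prompt position needs no logprob target. A toy illustration of that shift (made-up token ids):

# Illustrative only: target alignment for prompt logprobs.
prompt = [101, 7, 8, 9]   # hypothetical prompt token ids at positions 0..3
# Hidden states at positions 0..2 predict the tokens at positions 1..3.
targets = prompt[1:]      # -> [7, 8, 9]; position 3 has no in-prompt target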
vllm/v1/worker/gpu/sample/sampler.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState


class Sampler:
    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        req_states: RequestState,
        logprobs_mode: LogprobsMode = "raw_logprobs",
        num_speculative_tokens: int = 1,
    ):
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(req_states)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)
        self.bad_words_state = BadWordsState(req_states)
        self.num_speculative_tokens = num_speculative_tokens

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
        self.bad_words_state.add_request(req_idx, sampling_params)

    def apply_staged_writes(self) -> None:
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes()
        self.logit_bias_state.apply_staged_writes()
        self.bad_words_state.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        cu_num_logits_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> SamplerOutput:
        # NOTE(woosuk): num_nans is intentionally computed here, before sampling,
        # so that it reflects the raw logits before penalties and temperature are
        # applied.
        num_nans = get_num_nans(logits) if self.compute_nans else None
        sampled, processed_logits = self.sample(
            logits,
            idx_mapping,
            idx_mapping_np,
            pos,
            input_ids,
            expanded_local_pos,
        )

        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            if self.logprobs_mode == "processed_logprobs":
                logits = processed_logits
            expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
            cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
            logprobs_tensors = compute_topk_logprobs(
                logits, max_num_logprobs, sampled, cu_num_logits
            )
        else:
            logprobs_tensors = None

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )
        return sampler_output

    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)

        # Apply penalties in place.
        self.penalties_state.apply_penalties(
            logits,
            idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
            self.num_speculative_tokens,
        )

        # Apply bad words masking in place.
        self.bad_words_state.apply_bad_words(
            logits,
            idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
        )

        # Apply temperature in place.
        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)

        # Apply min_p in place.
        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)

        # Apply top_k and/or top_p. This might or might not return a new tensor.
        logits = self.sampling_states.apply_top_k_top_p(
            logits, idx_mapping, idx_mapping_np
        )

        # Sample the next token.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits
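For orientation, the processing order in sample() is: logit bias / allowed tokens / min-tokens, penalties, bad words, temperature, min-p, top-k/top-p, then Gumbel sampling on the already-scaled logits. A simplified dense PyTorch sketch of the tail of that pipeline (illustrative only; it drops the logit-bias, penalty, and bad-words stages as well as the per-request seeding):

# Illustrative only: dense, single-tensor version of the sampling tail.
import math

import torch

def sample_reference(logits: torch.Tensor, temperature: float, top_k: int, min_p: float) -> torch.Tensor:
    logits = logits.float()
    if temperature == 0.0:
        return logits.argmax(dim=-1)                       # greedy decoding
    logits = logits / temperature                          # temperature scaling
    if min_p > 0.0:
        thr = logits.max(dim=-1, keepdim=True).values + math.log(min_p)
        logits = logits.masked_fill(logits < thr, float("-inf"))
    if 0 < top_k < logits.shape[-1]:                       # top-k filtering
        kth = logits.topk(top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    u = torch.rand_like(logits).clamp_min(1e-7)
    gumbel_noise = -torch.log(-torch.log(u))               # Gumbel-max sampling
    return (logits + gumbel_noise).argmax(dim=-1)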
vllm/v1/worker/gpu/sample/states.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p

NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max


class SamplingStates:
    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size

        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)

        # Initialize top_k and top_p manually because 0 is an invalid value for them.
        self.top_k.np.fill(self.vocab_size)
        self.top_k.copy_to_uva()
        self.top_p.np.fill(1.0)
        self.top_p.copy_to_uva()

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(NO_LOGPROBS)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        top_k = sampling_params.top_k
        if top_k <= 0 or top_k > self.vocab_size:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k
        self.min_p.np[req_idx] = sampling_params.min_p

        seed = sampling_params.seed
        if seed is None:
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
        self.seeds.np[req_idx] = seed

        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = NO_LOGPROBS
        self.num_logprobs[req_idx] = num_logprobs

    def apply_staged_writes(self) -> None:
        self.temperature.copy_to_uva()
        self.top_p.copy_to_uva()
        self.top_k.copy_to_uva()
        self.min_p.copy_to_uva()
        self.seeds.copy_to_uva()

    def apply_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> None:
        temp_np = self.temperature.np[idx_mapping_np]
        if np.all((temp_np == 0.0) | (temp_np == 1.0)):
            # No request requires temperature. Skip the kernel launch.
            return

        apply_temperature(logits, idx_mapping, self.temperature.gpu)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> None:
        if np.all(self.min_p.np[idx_mapping_np] == 0.0):
            # No request uses min_p. Skip the kernel launch.
            return
        apply_min_p(logits, idx_mapping, self.min_p.gpu)

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> torch.Tensor:
        do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
        do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
        if not (do_top_k or do_top_p):
            return logits

        top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
        top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
        return apply_top_k_top_p(logits, top_k, top_p)

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
        return int(np.max(self.num_logprobs[idx_mapping_np]))