Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -72,7 +72,7 @@ class BadWordsState:
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
@@ -84,7 +84,7 @@ class BadWordsState:
|
||||
|
||||
apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
self.bad_word_token_ids.gpu,
|
||||
self.bad_word_offsets.gpu,
|
||||
self.num_bad_words.gpu,
|
||||
@@ -114,17 +114,17 @@ def _bad_words_kernel(
|
||||
input_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
):
|
||||
logit_idx = tl.program_id(0)
|
||||
token_idx = tl.program_id(0)
|
||||
bw_idx = tl.program_id(1)
|
||||
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
|
||||
|
||||
if bw_idx >= num_bad_words:
|
||||
return
|
||||
|
||||
pos = tl.load(expanded_local_pos_ptr + logit_idx)
|
||||
cur_req_first_pos = logit_idx - pos
|
||||
pos = tl.load(expanded_local_pos_ptr + token_idx)
|
||||
cur_req_first_pos = token_idx - pos
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
@@ -159,7 +159,7 @@ def _bad_words_kernel(
|
||||
match = match & (expected == actual)
|
||||
|
||||
if match:
|
||||
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
|
||||
tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
@@ -175,8 +175,8 @@ def apply_bad_words(
|
||||
expanded_local_pos: torch.Tensor,
|
||||
max_num_bad_words: int,
|
||||
) -> None:
|
||||
total_num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
|
||||
num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(num_tokens, max_num_bad_words)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
expanded_idx_mapping,
|
||||
|
||||
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
|
||||
def _temperature_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
temperature_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
|
||||
if temperature == 0.0 or temperature == 1.0:
|
||||
# Early return to avoid loading logits.
|
||||
@@ -25,24 +25,24 @@ def _temperature_kernel(
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
||||
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
|
||||
logits = logits.to(tl.float32)
|
||||
logits = logits / temperature
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
_temperature_kernel[(num_reqs, num_blocks)](
|
||||
_temperature_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + block,
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
pos = tl.load(pos_ptr + token_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise in FP32.
|
||||
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
|
||||
|
||||
value, idx = tl.max(logits, axis=0, return_indices=True)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
|
||||
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
idx_mapping: torch.Tensor, # [max_num_reqs]
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
expanded_idx_mapping: torch.Tensor, # [num_tokens]
|
||||
temperature: torch.Tensor, # [max_num_reqs]
|
||||
seed: torch.Tensor, # [max_num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_tokens]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
|
||||
@@ -121,7 +121,7 @@ class LogitBiasState:
|
||||
def apply_logit_bias(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> None:
|
||||
@@ -131,7 +131,7 @@ class LogitBiasState:
|
||||
|
||||
apply_logit_bias(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
pos,
|
||||
self.num_allowed_token_ids.gpu,
|
||||
self.allowed_token_ids.gpu,
|
||||
@@ -149,7 +149,7 @@ def _bias_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
vocab_size,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
# Allowed token IDs.
|
||||
num_allowed_token_ids_ptr,
|
||||
allowed_token_ids_ptr,
|
||||
@@ -169,8 +169,8 @@ def _bias_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
LOGITS_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
@@ -186,21 +186,21 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
|
||||
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
|
||||
)
|
||||
|
||||
# Set logits to -inf for all tokens.
|
||||
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + offset,
|
||||
logits_ptr + token_idx * logits_stride + offset,
|
||||
-float("inf"),
|
||||
mask=offset < vocab_size,
|
||||
)
|
||||
|
||||
# Restore logits for allowed token IDs.
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
|
||||
logits_ptr + token_idx * logits_stride + allowed_token_ids,
|
||||
logits,
|
||||
mask=mask,
|
||||
)
|
||||
@@ -214,13 +214,13 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
|
||||
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
|
||||
logits += bias
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
|
||||
|
||||
# Apply min tokens.
|
||||
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
pos = tl.load(pos_ptr + token_idx)
|
||||
min_len = tl.load(min_lens_ptr + req_state_idx)
|
||||
if num_stop_token_ids > 0 and pos < min_len:
|
||||
mask = block < num_stop_token_ids
|
||||
@@ -229,7 +229,7 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + stop_token_ids,
|
||||
logits_ptr + token_idx * logits_stride + stop_token_ids,
|
||||
-float("inf"),
|
||||
mask=mask,
|
||||
)
|
||||
@@ -237,7 +237,7 @@ def _bias_kernel(
|
||||
|
||||
def apply_logit_bias(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
num_allowed_token_ids: torch.Tensor,
|
||||
allowed_token_ids: torch.Tensor,
|
||||
@@ -248,7 +248,7 @@ def apply_logit_bias(
|
||||
num_stop_token_ids: torch.Tensor,
|
||||
stop_token_ids: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(
|
||||
max(
|
||||
allowed_token_ids.shape[-1],
|
||||
@@ -257,11 +257,11 @@ def apply_logit_bias(
|
||||
)
|
||||
)
|
||||
LOGITS_BLOCK_SIZE = 8192
|
||||
_bias_kernel[(num_reqs,)](
|
||||
_bias_kernel[(num_tokens,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
vocab_size,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
num_allowed_token_ids,
|
||||
allowed_token_ids,
|
||||
allowed_token_ids.stride(0),
|
||||
|
||||
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
|
||||
def _min_p_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
min_p_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
|
||||
if min_p == 0.0:
|
||||
return
|
||||
@@ -25,7 +25,9 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
@@ -35,21 +37,23 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = tl.where(logits < threshold, float("-inf"), logits)
|
||||
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
_min_p_kernel[(num_reqs,)](
|
||||
_min_p_kernel[(num_tokens,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
min_p,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
|
||||
@@ -82,7 +82,7 @@ class PenaltiesState:
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
@@ -94,7 +94,7 @@ class PenaltiesState:
|
||||
|
||||
apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
self.repetition_penalty.gpu,
|
||||
@@ -110,7 +110,7 @@ class PenaltiesState:
|
||||
def _penalties_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
token_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
repetition_penalty_ptr,
|
||||
@@ -125,7 +125,7 @@ def _penalties_kernel(
|
||||
MAX_SPEC_LEN: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
|
||||
@@ -191,7 +191,7 @@ def _penalties_kernel(
|
||||
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
repetition_penalty: torch.Tensor,
|
||||
@@ -207,7 +207,7 @@ def apply_penalties(
|
||||
_penalties_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
token_ids,
|
||||
expanded_local_pos,
|
||||
repetition_penalty,
|
||||
@@ -225,7 +225,7 @@ def apply_penalties(
|
||||
|
||||
@triton.jit
|
||||
def _bincount_kernel(
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
@@ -236,9 +236,9 @@ def _bincount_kernel(
|
||||
output_bin_counts_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
token_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
@@ -276,7 +276,7 @@ def _bincount_kernel(
|
||||
|
||||
|
||||
def bincount(
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
@@ -284,13 +284,13 @@ def bincount(
|
||||
output_bin_counts: torch.Tensor,
|
||||
max_prefill_len: int,
|
||||
) -> None:
|
||||
prompt_bin_mask[idx_mapping] = 0
|
||||
output_bin_counts[idx_mapping] = 0
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
prompt_bin_mask[expanded_idx_mapping] = 0
|
||||
output_bin_counts[expanded_idx_mapping] = 0
|
||||
num_tokens = expanded_idx_mapping.shape[0]
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_reqs, num_blocks)](
|
||||
idx_mapping,
|
||||
_bincount_kernel[(num_tokens, num_blocks)](
|
||||
expanded_idx_mapping,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
|
||||
@@ -56,7 +56,7 @@ class Sampler:
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
cu_num_logits_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
@@ -68,7 +68,7 @@ class Sampler:
|
||||
num_nans = get_num_nans(logits) if self.compute_nans else None
|
||||
sampled, processed_logits = self.sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
pos,
|
||||
input_ids,
|
||||
@@ -101,7 +101,7 @@ class Sampler:
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -111,12 +111,14 @@ class Sampler:
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
|
||||
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
|
||||
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
|
||||
self.logit_bias_state.apply_logit_bias(
|
||||
logits, expanded_idx_mapping, idx_mapping_np, pos
|
||||
)
|
||||
|
||||
# Apply penalties in place.
|
||||
self.penalties_state.apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
@@ -126,27 +128,29 @@ class Sampler:
|
||||
# Apply bad words masking in place.
|
||||
self.bad_words_state.apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
# Apply temperature in place.
|
||||
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
|
||||
self.sampling_states.apply_temperature(
|
||||
logits, expanded_idx_mapping, idx_mapping_np
|
||||
)
|
||||
|
||||
# Apply min_p in place.
|
||||
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
|
||||
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
|
||||
|
||||
# Apply top_k and/or top_p. This might or might not return a new tensor.
|
||||
logits = self.sampling_states.apply_top_k_top_p(
|
||||
logits, idx_mapping, idx_mapping_np
|
||||
logits, expanded_idx_mapping, idx_mapping_np
|
||||
)
|
||||
|
||||
# Sample the next token.
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
self.sampling_states.temperature.gpu,
|
||||
self.sampling_states.seeds.gpu,
|
||||
pos,
|
||||
|
||||
@@ -64,7 +64,7 @@ class SamplingStates:
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
temp_np = self.temperature.np[idx_mapping_np]
|
||||
@@ -72,23 +72,23 @@ class SamplingStates:
|
||||
# No request requires temperature. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_temperature(logits, idx_mapping, self.temperature.gpu)
|
||||
apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
|
||||
# No request uses min_p. Skip the kernel launch.
|
||||
return
|
||||
apply_min_p(logits, idx_mapping, self.min_p.gpu)
|
||||
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
|
||||
|
||||
def apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
|
||||
@@ -96,8 +96,8 @@ class SamplingStates:
|
||||
if not (do_top_k or do_top_p):
|
||||
return logits
|
||||
|
||||
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
|
||||
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
|
||||
|
||||
Reference in New Issue
Block a user