Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -72,7 +72,7 @@ class BadWordsState:
     def apply_bad_words(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
         input_ids: torch.Tensor,
         expanded_local_pos: torch.Tensor,
@@ -84,7 +84,7 @@
         apply_bad_words(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             self.bad_word_token_ids.gpu,
             self.bad_word_offsets.gpu,
             self.num_bad_words.gpu,
@@ -114,17 +114,17 @@ def _bad_words_kernel(
     input_ids_ptr,
     expanded_local_pos_ptr,
 ):
-    logit_idx = tl.program_id(0)
+    token_idx = tl.program_id(0)
     bw_idx = tl.program_id(1)
-    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
     if bw_idx >= num_bad_words:
         return
-    pos = tl.load(expanded_local_pos_ptr + logit_idx)
-    cur_req_first_pos = logit_idx - pos
+    pos = tl.load(expanded_local_pos_ptr + token_idx)
+    cur_req_first_pos = token_idx - pos
     prompt_len = tl.load(prompt_len_ptr + req_state_idx)
     total_len = tl.load(total_len_ptr + req_state_idx)
@@ -159,7 +159,7 @@ def _bad_words_kernel(
         match = match & (expected == actual)
     if match:
-        tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
+        tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))

 def apply_bad_words(
@@ -175,8 +175,8 @@ def apply_bad_words(
     expanded_local_pos: torch.Tensor,
     max_num_bad_words: int,
 ) -> None:
-    total_num_tokens = logits.shape[0]
-    _bad_words_kernel[(total_num_tokens, max_num_bad_words)](
+    num_tokens = logits.shape[0]
+    _bad_words_kernel[(num_tokens, max_num_bad_words)](
         logits,
         logits.stride(0),
         expanded_idx_mapping,

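Across all of these kernels, the change replaces the per-request `idx_mapping` (one logits row per request) with a per-token `expanded_idx_mapping` (one logits row per sampled token, so speculative-decoding drafts get their own rows) and renames the grid index accordingly. The kernel itself masks the final token of any bad word whose prefix matches the tail of the generated sequence. For intuition, a single-row PyTorch sketch of that matching rule — `mask_bad_words_ref` is a hypothetical reference helper, not code from this commit:

import torch

def mask_bad_words_ref(logits: torch.Tensor,        # [vocab_size]
                       token_ids: list[int],        # prompt + output so far
                       bad_words: list[list[int]]) -> None:
    # A bad word [t0, ..., tn] is blocked by masking tn whenever the
    # already-generated tail equals [t0, ..., tn-1].
    for word in bad_words:
        prefix, last = word[:-1], word[-1]
        n = len(prefix)
        if n == 0 or token_ids[-n:] == prefix:
            logits[last] = float("-inf")

logits = torch.zeros(10)
mask_bad_words_ref(logits, token_ids=[1, 2, 3], bad_words=[[2, 3, 7]])
print(logits[7])  # tensor(-inf)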

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
 def _temperature_kernel(
     logits_ptr,
     logits_stride,
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     temperature_ptr,
     vocab_size,
     BLOCK_SIZE: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
     if temperature == 0.0 or temperature == 1.0:
         # Early return to avoid loading logits.
@@ -25,24 +25,24 @@ def _temperature_kernel(
         block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         mask = block < vocab_size
-        logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
+        logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
         logits = logits.to(tl.float32)
         logits = logits / temperature
-        tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+        tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)

 def apply_temperature(
     logits: torch.Tensor,
-    idx_mapping: torch.Tensor,
+    expanded_idx_mapping: torch.Tensor,
     temperature: torch.Tensor,
 ) -> None:
-    num_reqs, vocab_size = logits.shape
+    num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 8192
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    _temperature_kernel[(num_reqs, num_blocks)](
+    _temperature_kernel[(num_tokens, num_blocks)](
         logits,
         logits.stride(0),
-        idx_mapping,
+        expanded_idx_mapping,
         temperature,
         vocab_size,
         BLOCK_SIZE=BLOCK_SIZE,
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
     local_max_stride,
     logits_ptr,
     logits_stride,
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     seeds_ptr,
     pos_ptr,
     temp_ptr,
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
     BLOCK_SIZE: tl.constexpr,
     APPLY_TEMPERATURE: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     block_idx = tl.program_id(1)
     block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = block < vocab_size
     logits = tl.load(
-        logits_ptr + batch_idx * logits_stride + block,
+        logits_ptr + token_idx * logits_stride + block,
         mask=mask,
         other=float("-inf"),
     )
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
     if temp != 0.0:
         # Calculate the seed for gumbel noise.
         seed = tl.load(seeds_ptr + req_state_idx)
-        pos = tl.load(pos_ptr + batch_idx)
+        pos = tl.load(pos_ptr + token_idx)
         gumbel_seed = tl.randint(seed, pos)
         # Generate gumbel noise in FP32.
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
     value, idx = tl.max(logits, axis=0, return_indices=True)
     token_id = block_idx * BLOCK_SIZE + idx
-    tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
-    tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
+    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
+    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)

 def gumbel_sample(
-    logits: torch.Tensor,  # [num_reqs, vocab_size]
-    idx_mapping: torch.Tensor,  # [max_num_reqs]
+    logits: torch.Tensor,  # [num_tokens, vocab_size]
+    expanded_idx_mapping: torch.Tensor,  # [num_tokens]
     temperature: torch.Tensor,  # [max_num_reqs]
     seed: torch.Tensor,  # [max_num_reqs]
-    pos: torch.Tensor,  # [num_reqs]
+    pos: torch.Tensor,  # [num_tokens]
     apply_temperature: bool,
 ) -> torch.Tensor:
-    num_reqs, vocab_size = logits.shape
+    num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
     local_argmax = torch.empty(
-        num_reqs,
+        num_tokens,
         num_blocks,
         dtype=torch.int64,
         device=logits.device,
     )
     local_max = torch.empty(
-        num_reqs,
+        num_tokens,
         num_blocks,
         dtype=torch.float32,
         device=logits.device,
     )
-    _gumbel_sample_kernel[(num_reqs, num_blocks)](
+    _gumbel_sample_kernel[(num_tokens, num_blocks)](
         local_argmax,
         local_argmax.stride(0),
         local_max,
         local_max.stride(0),
         logits,
         logits.stride(0),
-        idx_mapping,
+        expanded_idx_mapping,
         seed,
         pos,
         temperature,

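The kernel pair above implements the Gumbel-max trick: adding independent Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax is equivalent to sampling from the softmax distribution, which is why a block-wise argmax over `local_max`/`local_argmax` followed by a global reduction is all that is needed. A dense PyTorch reference of the idea (a sketch assuming per-row temperatures already gathered; not the Triton code above):

import torch

def gumbel_sample_ref(logits: torch.Tensor,       # [num_tokens, vocab_size]
                      temperature: torch.Tensor,  # [num_tokens]
                      ) -> torch.Tensor:
    t = temperature.unsqueeze(-1)
    scaled = logits / t.clamp_min(1e-6)
    # Gumbel(0, 1) noise via the inverse-CDF transform of a uniform.
    u = torch.rand_like(scaled).clamp_min(1e-10)
    gumbel = -torch.log(-torch.log(u))
    # Greedy rows (temperature == 0) take a plain argmax of the raw logits.
    noisy = torch.where(t == 0.0, logits, scaled + gumbel)
    return noisy.argmax(dim=-1)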

@@ -121,7 +121,7 @@ class LogitBiasState:
     def apply_logit_bias(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
         pos: torch.Tensor,
     ) -> None:
@@ -131,7 +131,7 @@
         apply_logit_bias(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             pos,
             self.num_allowed_token_ids.gpu,
             self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
     logits_ptr,
     logits_stride,
     vocab_size,
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     # Allowed token IDs.
     num_allowed_token_ids_ptr,
     allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
     BLOCK_SIZE: tl.constexpr,
     LOGITS_BLOCK_SIZE: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     block = tl.arange(0, BLOCK_SIZE)
@@ -186,21 +186,21 @@ def _bias_kernel(
             mask=mask,
         )
         logits = tl.load(
-            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
+            logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
         )
         # Set logits to -inf for all tokens.
         for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
             offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
             tl.store(
-                logits_ptr + batch_idx * logits_stride + offset,
+                logits_ptr + token_idx * logits_stride + offset,
                 -float("inf"),
                 mask=offset < vocab_size,
             )
         # Restore logits for allowed token IDs.
         tl.store(
-            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
+            logits_ptr + token_idx * logits_stride + allowed_token_ids,
             logits,
             mask=mask,
         )
@@ -214,13 +214,13 @@ def _bias_kernel(
             mask=mask,
         )
         bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
-        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
+        logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
         logits += bias
-        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
+        tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
     # Apply min tokens.
     num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
-    pos = tl.load(pos_ptr + batch_idx)
+    pos = tl.load(pos_ptr + token_idx)
     min_len = tl.load(min_lens_ptr + req_state_idx)
     if num_stop_token_ids > 0 and pos < min_len:
         mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
             mask=mask,
         )
         tl.store(
-            logits_ptr + batch_idx * logits_stride + stop_token_ids,
+            logits_ptr + token_idx * logits_stride + stop_token_ids,
             -float("inf"),
             mask=mask,
         )
@@ -237,7 +237,7 @@ def _bias_kernel(

 def apply_logit_bias(
     logits: torch.Tensor,
-    idx_mapping: torch.Tensor,
+    expanded_idx_mapping: torch.Tensor,
     pos: torch.Tensor,
     num_allowed_token_ids: torch.Tensor,
     allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
     num_stop_token_ids: torch.Tensor,
     stop_token_ids: torch.Tensor,
 ) -> None:
-    num_reqs, vocab_size = logits.shape
+    num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = triton.next_power_of_2(
         max(
             allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
         )
     )
     LOGITS_BLOCK_SIZE = 8192
-    _bias_kernel[(num_reqs,)](
+    _bias_kernel[(num_tokens,)](
         logits,
         logits.stride(0),
         vocab_size,
-        idx_mapping,
+        expanded_idx_mapping,
         num_allowed_token_ids,
         allowed_token_ids,
         allowed_token_ids.stride(0),

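One pass of `_bias_kernel` handles three per-row adjustments: restricting sampling to `allowed_token_ids`, adding a per-request logit bias, and enforcing `min_tokens` by masking stop tokens until the row's position reaches the request's minimum length. A dense sketch of the min-tokens step alone, assuming the per-row values have already been gathered through `expanded_idx_mapping` (hypothetical layout, for illustration only):

import torch

def apply_min_tokens_ref(logits: torch.Tensor,          # [num_tokens, vocab]
                         pos: torch.Tensor,             # [num_tokens]
                         min_lens: torch.Tensor,        # [num_tokens]
                         stop_token_ids: torch.Tensor,  # [num_tokens, max_stops]
                         num_stop_token_ids: torch.Tensor) -> None:
    # Rows that have not yet produced min_tokens may not emit a stop token.
    active = (num_stop_token_ids > 0) & (pos < min_lens)
    for row in active.nonzero(as_tuple=True)[0].tolist():
        ids = stop_token_ids[row, : int(num_stop_token_ids[row])]
        logits[row, ids] = float("-inf")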

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
 def _min_p_kernel(
     logits_ptr,
     logits_stride,
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     min_p_ptr,
     vocab_size,
     BLOCK_SIZE: tl.constexpr,
 ):
-    req_idx = tl.program_id(0)
-    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
     if min_p == 0.0:
         return
@@ -25,7 +25,9 @@ def _min_p_kernel(
         block = i + tl.arange(0, BLOCK_SIZE)
         mask = block < vocab_size
         logits = tl.load(
-            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
+            logits_ptr + token_idx * logits_stride + block,
+            mask=mask,
+            other=float("-inf"),
         )
         max_val = tl.max(tl.maximum(logits, max_val))
     max_val = max_val.to(tl.float32)  # type: ignore
@@ -35,21 +37,23 @@ def _min_p_kernel(
         block = i + tl.arange(0, BLOCK_SIZE)
         mask = block < vocab_size
         logits = tl.load(
-            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
+            logits_ptr + token_idx * logits_stride + block,
+            mask=mask,
+            other=float("-inf"),
         )
         logits = tl.where(logits < threshold, float("-inf"), logits)
-        tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
+        tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)

 def apply_min_p(
-    logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
+    logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
 ) -> None:
-    num_reqs, vocab_size = logits.shape
+    num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
-    _min_p_kernel[(num_reqs,)](
+    _min_p_kernel[(num_tokens,)](
         logits,
         logits.stride(0),
-        idx_mapping,
+        expanded_idx_mapping,
         min_p,
         vocab_size,
         BLOCK_SIZE=BLOCK_SIZE,

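Min-p keeps only tokens whose probability is at least `min_p` times the row's maximum probability; on raw logits that threshold is `max_logit + log(min_p)`, which is what the kernel derives from its block-wise `max_val` scan. A dense reference (a sketch with per-row `min_p` already gathered; `min_p == 0` disables filtering, matching the kernel's early return):

import torch

def apply_min_p_ref(logits: torch.Tensor,  # [num_tokens, vocab_size]
                    min_p: torch.Tensor,   # [num_tokens]
                    ) -> torch.Tensor:
    max_logit = logits.amax(dim=-1, keepdim=True)
    # p(tok) >= min_p * p(max)  <=>  logit(tok) >= max_logit + log(min_p)
    threshold = max_logit + torch.log(min_p.unsqueeze(-1))
    keep = (logits >= threshold) | (min_p.unsqueeze(-1) == 0.0)
    return logits.masked_fill(~keep, float("-inf"))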

@@ -82,7 +82,7 @@ class PenaltiesState:
     def apply_penalties(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
         input_ids: torch.Tensor,
         expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:
         apply_penalties(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             input_ids,
             expanded_local_pos,
             self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
 def _penalties_kernel(
     logits_ptr,
     logits_stride,
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     token_ids_ptr,
     expanded_local_pos_ptr,
     repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
     MAX_SPEC_LEN: tl.constexpr,
 ):
     token_idx = tl.program_id(0)
-    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
     freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
     pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(

 def apply_penalties(
     logits: torch.Tensor,
-    idx_mapping: torch.Tensor,
+    expanded_idx_mapping: torch.Tensor,
     token_ids: torch.Tensor,
     expanded_local_pos: torch.Tensor,
     repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
     _penalties_kernel[(num_tokens, num_blocks)](
         logits,
         logits.stride(0),
-        idx_mapping,
+        expanded_idx_mapping,
         token_ids,
         expanded_local_pos,
         repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(

 @triton.jit
 def _bincount_kernel(
-    idx_mapping_ptr,
+    expanded_idx_mapping_ptr,
     all_token_ids_ptr,
     all_token_ids_stride,
     prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
     output_bin_counts_stride,
     BLOCK_SIZE: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
+    token_idx = tl.program_id(0)
     block_idx = tl.program_id(1)
-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
     prefill_len = tl.load(prefill_len_ptr + req_state_idx)
     if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(
 def bincount(
-    idx_mapping: torch.Tensor,
+    expanded_idx_mapping: torch.Tensor,
     all_token_ids: torch.Tensor,
     prompt_len: torch.Tensor,
     prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@
     output_bin_counts: torch.Tensor,
     max_prefill_len: int,
 ) -> None:
-    prompt_bin_mask[idx_mapping] = 0
-    output_bin_counts[idx_mapping] = 0
-    num_reqs = idx_mapping.shape[0]
+    prompt_bin_mask[expanded_idx_mapping] = 0
+    output_bin_counts[expanded_idx_mapping] = 0
+    num_tokens = expanded_idx_mapping.shape[0]
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
-    _bincount_kernel[(num_reqs, num_blocks)](
-        idx_mapping,
+    _bincount_kernel[(num_tokens, num_blocks)](
+        expanded_idx_mapping,
         all_token_ids,
         all_token_ids.stride(0),
         prompt_len,

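`bincount` fills per-slot `prompt_bin_mask` and `output_bin_counts` buffers that `_penalties_kernel` then reads for each logits row. For reference, the usual repetition/frequency/presence penalty arithmetic in dense form — a sketch in which the dense mask and count tensors stand in for the binned buffers and the per-row penalty scalars are assumed already gathered:

import torch

def apply_penalties_ref(logits: torch.Tensor,         # [num_tokens, vocab]
                        prompt_mask: torch.Tensor,    # bool: token in prompt
                        output_counts: torch.Tensor,  # counts in the output
                        rep_pen: torch.Tensor,        # [num_tokens] each
                        freq_pen: torch.Tensor,
                        pres_pen: torch.Tensor) -> torch.Tensor:
    seen = prompt_mask | (output_counts > 0)
    rp = rep_pen.unsqueeze(-1)
    # Repetition penalty shrinks positive logits and amplifies negative ones.
    scaled = torch.where(logits > 0, logits / rp, logits * rp)
    logits = torch.where(seen, scaled, logits)
    # Frequency and presence penalties act on generated tokens only.
    logits = logits - freq_pen.unsqueeze(-1) * output_counts
    logits = logits - pres_pen.unsqueeze(-1) * (output_counts > 0)
    return logits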

@@ -56,7 +56,7 @@ class Sampler:
     def __call__(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
         cu_num_logits_np: np.ndarray,
         pos: torch.Tensor,
@@ -68,7 +68,7 @@ class Sampler:
         num_nans = get_num_nans(logits) if self.compute_nans else None
         sampled, processed_logits = self.sample(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             idx_mapping_np,
             pos,
             input_ids,
@@ -101,7 +101,7 @@ class Sampler:
     def sample(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
         pos: torch.Tensor,
         input_ids: torch.Tensor,
@@ -111,12 +111,14 @@ class Sampler:
         logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
         # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
-        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
+        self.logit_bias_state.apply_logit_bias(
+            logits, expanded_idx_mapping, idx_mapping_np, pos
+        )
         # Apply penalties in place.
         self.penalties_state.apply_penalties(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             idx_mapping_np,
             input_ids,
             expanded_local_pos,
@@ -126,27 +128,29 @@ class Sampler:
         # Apply bad words masking in place.
         self.bad_words_state.apply_bad_words(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             idx_mapping_np,
             input_ids,
             expanded_local_pos,
         )
         # Apply temperature in place.
-        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
+        self.sampling_states.apply_temperature(
+            logits, expanded_idx_mapping, idx_mapping_np
+        )
         # Apply min_p in place.
-        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
+        self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
         # Apply top_k and/or top_p. This might or might not return a new tensor.
         logits = self.sampling_states.apply_top_k_top_p(
-            logits, idx_mapping, idx_mapping_np
+            logits, expanded_idx_mapping, idx_mapping_np
         )
         # Sample the next token.
         sampled = gumbel_sample(
             logits,
-            idx_mapping,
+            expanded_idx_mapping,
             self.sampling_states.temperature.gpu,
             self.sampling_states.seeds.gpu,
             pos,

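With speculative decoding a request can contribute several logits rows per step, so the per-token tensors (`expanded_idx_mapping`, `pos`, `expanded_local_pos`) have length `num_tokens >= num_reqs`. A sketch of how the expanded mapping relates to the per-request arrays visible in the `__call__` signature — the construction shown is an illustrative assumption, not code from this commit:

import numpy as np
import torch

idx_mapping_np = np.array([7, 3, 9])       # request -> request-state slot
cu_num_logits_np = np.array([0, 1, 4, 6])  # prefix sum of rows per request

rows_per_req = np.diff(cu_num_logits_np)   # [1, 3, 2]
expanded = np.repeat(idx_mapping_np, rows_per_req)
print(expanded)  # [7 3 3 3 9 9]: every logits row knows its state slot

expanded_idx_mapping = torch.from_numpy(expanded)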

@@ -64,7 +64,7 @@ class SamplingStates:
     def apply_temperature(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
     ) -> None:
         temp_np = self.temperature.np[idx_mapping_np]
@@ -72,23 +72,23 @@ class SamplingStates:
             # No request requires temperature. Skip the kernel launch.
             return
-        apply_temperature(logits, idx_mapping, self.temperature.gpu)
+        apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)

     def apply_min_p(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
     ) -> None:
         if np.all(self.min_p.np[idx_mapping_np] == 0.0):
             # No request uses min_p. Skip the kernel launch.
             return
-        apply_min_p(logits, idx_mapping, self.min_p.gpu)
+        apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)

     def apply_top_k_top_p(
         self,
         logits: torch.Tensor,
-        idx_mapping: torch.Tensor,
+        expanded_idx_mapping: torch.Tensor,
         idx_mapping_np: np.ndarray,
     ) -> torch.Tensor:
         do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
@@ -96,8 +96,8 @@ class SamplingStates:
         if not (do_top_k or do_top_p):
             return logits
-        top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
-        top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
+        top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
+        top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
         return apply_top_k_top_p(logits, top_k, top_p)

     def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
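
Since top-k/top-p also run on per-token rows, the per-request values are gathered with `expanded_idx_mapping` before the call, so a request's setting applies to every row it owns. For reference, variable-k top-k masking over such gathered values (a sketch; ties at the cutoff may keep a few extra tokens, and `k == vocab_size` disables the filter):

import torch

def apply_top_k_ref(logits: torch.Tensor,  # [num_tokens, vocab_size]
                    top_k: torch.Tensor,   # [num_tokens]
                    ) -> torch.Tensor:
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    # The k-th largest logit in each row is the cutoff value.
    kth = sorted_logits.gather(-1, (top_k.long() - 1).unsqueeze(-1))
    return logits.masked_fill(logits < kth, float("-inf"))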