Upgrade to vLLM 0.17.0 CoreX v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -121,7 +121,7 @@ class LogitBiasState:
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
@@ -131,7 +131,7 @@ class LogitBiasState:
apply_logit_bias(
logits,
idx_mapping,
expanded_idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block = tl.arange(0, BLOCK_SIZE)
@@ -186,21 +186,21 @@ def _bias_kernel(
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
)
# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
logits_ptr + token_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)
# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits_ptr + token_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
@@ -214,13 +214,13 @@ def _bias_kernel(
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
logits_ptr + token_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
@@ -237,7 +237,7 @@ def _bias_kernel(
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
_bias_kernel[(num_tokens,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
expanded_idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),