Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -82,7 +82,7 @@ class PenaltiesState:
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:
apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(
def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
token_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(
def bincount(
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@ def bincount(
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
prompt_bin_mask[expanded_idx_mapping] = 0
output_bin_counts[expanded_idx_mapping] = 0
num_tokens = expanded_idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
_bincount_kernel[(num_tokens, num_blocks)](
expanded_idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,