Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -82,7 +82,7 @@ class PenaltiesState:
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
@@ -94,7 +94,7 @@ class PenaltiesState:
|
||||
|
||||
apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
self.repetition_penalty.gpu,
|
||||
@@ -110,7 +110,7 @@ class PenaltiesState:
|
||||
def _penalties_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
token_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
repetition_penalty_ptr,
|
||||
@@ -125,7 +125,7 @@ def _penalties_kernel(
|
||||
MAX_SPEC_LEN: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
|
||||
@@ -191,7 +191,7 @@ def _penalties_kernel(
|
||||
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
repetition_penalty: torch.Tensor,
|
||||
@@ -207,7 +207,7 @@ def apply_penalties(
|
||||
_penalties_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
token_ids,
|
||||
expanded_local_pos,
|
||||
repetition_penalty,
|
||||
@@ -225,7 +225,7 @@ def apply_penalties(
|
||||
|
||||
@triton.jit
|
||||
def _bincount_kernel(
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
@@ -236,9 +236,9 @@ def _bincount_kernel(
|
||||
output_bin_counts_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
token_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
@@ -276,7 +276,7 @@ def _bincount_kernel(
|
||||
|
||||
|
||||
def bincount(
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
@@ -284,13 +284,13 @@ def bincount(
|
||||
output_bin_counts: torch.Tensor,
|
||||
max_prefill_len: int,
|
||||
) -> None:
|
||||
prompt_bin_mask[idx_mapping] = 0
|
||||
output_bin_counts[idx_mapping] = 0
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
prompt_bin_mask[expanded_idx_mapping] = 0
|
||||
output_bin_counts[expanded_idx_mapping] = 0
|
||||
num_tokens = expanded_idx_mapping.shape[0]
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_reqs, num_blocks)](
|
||||
idx_mapping,
|
||||
_bincount_kernel[(num_tokens, num_blocks)](
|
||||
expanded_idx_mapping,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
|
||||
Reference in New Issue
Block a user