Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
|
||||
def _min_p_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
min_p_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
|
||||
if min_p == 0.0:
|
||||
return
|
||||
@@ -25,7 +25,9 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
@@ -35,21 +37,23 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = tl.where(logits < threshold, float("-inf"), logits)
|
||||
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
_min_p_kernel[(num_reqs,)](
|
||||
_min_p_kernel[(num_tokens,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
min_p,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
|
||||
Reference in New Issue
Block a user