Simplify pytorch sampling kernel and logit processor (#2491)
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
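
Note: the hunk above drops the redundant `logits = None` that immediately preceded `del logits`; deleting the name alone is enough to release the reference once `probs` has been materialized. For context, a minimal standalone sketch of this post-processing step (the batch and vocab sizes are illustrative, not from the patch):

    import torch

    logits = torch.randn(2, 32000)                # (batch, vocab_size)
    temperatures = torch.tensor([[0.7], [1.0]])   # one temperature per request

    logits.div_(temperatures)                     # in-place temperature scaling
    probs = torch.softmax(logits, dim=-1)         # logits -> probability distribution
    del logits                                    # a single `del` releases the reference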
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
                 )
             else:
                 raise ValueError(
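
Note: the call site now forwards `sampling_info.need_min_p_sampling`, so the pytorch backend can skip min-p masking entirely when no request in the batch uses it. A hedged standalone call mirroring this branch (tensor values are invented for illustration, and the sampler function is the one defined in the next hunk):

    import torch

    # In sglang these come from sampling_info; here they are made up.
    probs = torch.softmax(torch.randn(2, 32000), dim=-1)
    top_ks = torch.tensor([40, 50])
    top_ps = torch.tensor([0.9, 0.95])
    min_ps = torch.tensor([0.05, 0.0])

    next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
        probs, top_ks, top_ps, min_ps, need_min_p_sampling=True
    )
    # next_token_ids has shape (2,): one sampled token id per request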
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
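
Note: in the rewritten torch sampler, top-k masking now runs before top-p, min-p masking is gated behind `need_min_p_sampling`, and the old rescaling by the per-row maximum (`probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])`) is gone. That rescaling was unnecessary because `torch.multinomial` accepts non-negative weights that need not sum to one, for example:

    import torch

    # Rows need not sum to 1; multinomial normalizes internally.
    weights = torch.tensor([[0.5, 0.25, 0.0, 0.0]])  # masked distribution, sums to 0.75
    idx = torch.multinomial(weights, num_samples=1)  # samples index 0 or 1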
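
Note: the new `top_p_normalize_probs` helper renormalizes a distribution after top-p truncation and scatters the result back into vocabulary order. A hedged walk-through of its pytorch branch on a tiny hand-made distribution (the helper itself reads the backend from `global_server_args_dict`, so its body is inlined here to keep the sketch standalone):

    import torch

    probs = torch.tensor([[0.1, 0.4, 0.2, 0.3]])
    top_ps = torch.tensor([0.6])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
    # renormed ~= tensor([[0.0000, 0.5714, 0.0000, 0.4286]])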