[Minor, Performance] Use torch.argmax for greedy sampling (#1589)

This commit is contained in:
Ying Sheng
2024-10-06 13:15:05 -07:00
committed by GitHub
parent 9c064bf78a
commit c98e84c21e
3 changed files with 34 additions and 2 deletions

View File

@@ -43,7 +43,10 @@ class Sampler(nn.Module):
torch.isnan(probs), torch.full_like(probs, 1e-10), probs
)
-if global_server_args_dict["sampling_backend"] == "flashinfer":
+if sampling_info.top_ks.max().item() <= 1:
+# Use torch.argmax if all requests use greedy sampling
+batch_next_token_ids = torch.argmax(probs, -1)
+elif global_server_args_dict["sampling_backend"] == "flashinfer":
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device