Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

This commit is contained in:
Lianmin Zheng
2024-09-10 17:11:16 -07:00
committed by GitHub
parent 3a6e8b6d78
commit 46094e0c1b
13 changed files with 99 additions and 61 deletions

View File

@@ -78,7 +78,7 @@ class Sampler(CustomOp):
         probs = self._get_probs(logits, sampling_info)
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
@@ -93,11 +93,15 @@ class Sampler(CustomOp):
             batch_next_token_ids, success = flashinfer_top_k_top_p(
                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
             )
-        else:
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )
         return SampleOutput(success, probs, batch_next_token_ids)