Deprecate --disable-flashinfer and introduce --attention-backend (#1380)
This commit is contained in:
@@ -78,7 +78,7 @@ class Sampler(CustomOp):
|
||||
|
||||
probs = self._get_probs(logits, sampling_info)
|
||||
|
||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
@@ -93,11 +93,15 @@ class Sampler(CustomOp):
|
||||
batch_next_token_ids, success = flashinfer_top_k_top_p(
|
||||
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
|
||||
)
|
||||
else:
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# Here we provide a slower fallback implementation.
|
||||
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
return SampleOutput(success, probs, batch_next_token_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user