Minor: improve sampler & remove unused fields from model_config.py (#11531)

This commit is contained in:
Lianmin Zheng
2025-10-13 11:04:44 -07:00
committed by GitHub
parent 728af88781
commit 5e3f7e7fa9
5 changed files with 23 additions and 9 deletions

View File

@@ -92,6 +92,12 @@ class Sampler(nn.Module):
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
can_sample_directly_from_probs = (
not sampling_info.need_top_p_sampling
and not sampling_info.need_top_k_sampling
and not sampling_info.need_min_p_sampling
)
# If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB:
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
@@ -102,7 +108,14 @@ class Sampler(nn.Module):
probs = logits
del logits
if True: # Keep this redundant check to simplify some internal code sync
if can_sample_directly_from_probs:
# when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
batch_next_token_ids = sampling_from_probs_torch(
probs,
sampling_seed=sampling_info.sampling_seed,
positions=positions,
)
else:
if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)