Fuse top_k and top_k in the sampler (#1457)

This commit is contained in:
Lianmin Zheng
2024-09-18 04:35:35 -07:00
committed by GitHub
parent 1acccb364a
commit 7f24ea95c3
3 changed files with 12 additions and 4 deletions

View File

@@ -31,8 +31,11 @@ class Sampler(nn.Module):
logits = logits.next_token_logits
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
probs = logits[:] = torch.softmax(logits, dim=-1)
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
if torch.any(torch.isnan(probs)):
logger.warning("Detected errors during sampling! NaN in the probability.")
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if not torch.all(success):

View File

@@ -400,8 +400,8 @@ class ModelRunner:
)
self.req_to_token_pool = ReqToTokenPool(
max_num_reqs,
self.model_config.context_len + 8,
max_num_reqs + 1,
self.model_config.context_len + 4,
)
if (
self.model_config.attention_arch == AttentionArch.MLA