Fuse top_k and top_k in the sampler (#1457)
This commit is contained in:
@@ -23,6 +23,7 @@ class GenerateReqInput:
|
|||||||
# Whether to return logprobs.
|
# Whether to return logprobs.
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# The start location of the prompt for return_logprob.
|
# The start location of the prompt for return_logprob.
|
||||||
|
# By default, this value is "-1", which means it will only return logprobs for output tokens.
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# The number of top logprobs to return.
|
# The number of top logprobs to return.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
|
|||||||
@@ -31,8 +31,11 @@ class Sampler(nn.Module):
|
|||||||
logits = logits.next_token_logits
|
logits = logits.next_token_logits
|
||||||
|
|
||||||
# Post process logits
|
# Post process logits
|
||||||
|
logits = logits.contiguous()
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
probs = logits[:] = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
logits = None
|
||||||
|
del logits
|
||||||
|
|
||||||
if torch.any(torch.isnan(probs)):
|
if torch.any(torch.isnan(probs)):
|
||||||
logger.warning("Detected errors during sampling! NaN in the probability.")
|
logger.warning("Detected errors during sampling! NaN in the probability.")
|
||||||
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||||
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
|
probs,
|
||||||
|
uniform_samples,
|
||||||
|
sampling_info.top_ks,
|
||||||
|
sampling_info.top_ps,
|
||||||
|
filter_apply_order="joint",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.all(success):
|
if not torch.all(success):
|
||||||
|
|||||||
@@ -400,8 +400,8 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
max_num_reqs,
|
max_num_reqs + 1,
|
||||||
self.model_config.context_len + 8,
|
self.model_config.context_len + 4,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
|
|||||||
Reference in New Issue
Block a user