Tiny fix sampler error when prob is not contiguous (#6639)
This commit is contained in:
@@ -101,7 +101,7 @@ class Sampler(nn.Module):
|
|||||||
# Check Nan will throw exception, only check when crash_on_warnings is True
|
# Check Nan will throw exception, only check when crash_on_warnings is True
|
||||||
check_nan = self.use_nan_detection and crash_on_warnings()
|
check_nan = self.use_nan_detection and crash_on_warnings()
|
||||||
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||||
probs,
|
probs.contiguous(),
|
||||||
sampling_info.top_ks,
|
sampling_info.top_ks,
|
||||||
sampling_info.top_ps,
|
sampling_info.top_ps,
|
||||||
filter_apply_order="joint",
|
filter_apply_order="joint",
|
||||||
|
|||||||
Reference in New Issue
Block a user