Tiny fix sampler error when prob is not contiguous (#6639)

This commit is contained in:
fzyzcjy
2025-05-27 10:19:08 +08:00
committed by GitHub
parent eb8f02dd87
commit ca95556c76

View File

@@ -101,7 +101,7 @@ class Sampler(nn.Module):
# Check Nan will throw exception, only check when crash_on_warnings is True
check_nan = self.use_nan_detection and crash_on_warnings()
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
probs.contiguous(),
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",