Add dtype for more operations (#1705)

This commit is contained in:
Lianmin Zheng
2024-10-18 12:18:15 -07:00
committed by GitHub
parent 6d0fa73ece
commit 392f2863c8
3 changed files with 5 additions and 4 deletions

View File

@@ -57,7 +57,7 @@ class SamplingBatchInfo:
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float