Fix eagle on AMD (#7051)

This commit is contained in:
Lianmin Zheng
2025-06-10 05:22:40 -07:00
committed by GitHub
parent 2dae104dca
commit 019851d099
2 changed files with 4 additions and 1 deletions

View File

@@ -123,6 +123,9 @@ class EagleDraftInput:
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
if paged_kernel_lens_sum is None:
paged_kernel_lens_sum = cum_kv_seq_len[-1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)