From 019851d099543852130e560c6160c572d5f70b09 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 10 Jun 2025 05:22:40 -0700 Subject: [PATCH] Fix eagle on AMD (#7051) --- python/sglang/srt/speculative/eagle_utils.py | 3 +++ test/srt/test_bench_serving.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 577e8009b..8bb1222da 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -123,6 +123,9 @@ class EagleDraftInput: cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + if paged_kernel_lens_sum is None: + paged_kernel_lens_sum = cum_kv_seq_len[-1] + kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 83e6b71a1..f7317b191 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -194,7 +194,7 @@ class TestBenchServing(CustomTestCase): self.assertLess(res["median_ttft_ms"], 150) # TODO: not set yet, need AMD machine else: - self.assertLess(res["median_ttft_ms"], 94) + self.assertLess(res["median_ttft_ms"], 98) self.assertLess(res["median_itl_ms"], 8) def test_online_latency_eagle(self):