From 7ddf8e83d2d9f1d423bc3cae999ef9366fd0959c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 16 Jun 2025 05:47:51 -0700
Subject: [PATCH] [EAGLE] Fix draft kv cache layout for fa3 and topk > 1
 (#7239)

---
 .../srt/layers/attention/flashattention_backend.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 97eead3af..6cbca78e9 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -406,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=device,
             )
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
             cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T.contiguous()
+                -1, self.speculative_num_steps
+            )
             metadata_expand.page_table = (
                 cache_loc[:, :decode_length].contiguous().to(torch.int32)
             )
@@ -1636,9 +1637,8 @@ class FlashAttentionBackend(AttentionBackend):
         # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
         metadata_expand = self.draft_decode_metadata_topk_expand[bs]
         decode_length = self.speculative_step_id + 1
-        cache_loc = out_cache_loc.view(
-            self.speculative_num_steps, -1
-        ).T.contiguous()
+        # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+        cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
         metadata_expand.page_table[: cache_loc.shape[0]].copy_(
             cache_loc[:, :decode_length]
         )
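
A minimal sketch (not part of the patch) of what the reshape change does,
assuming illustrative sizes bs=2, topk=2, num_steps=3 and a flat out_cache_loc
whose fastest-varying dimension is the speculative step:

    import torch

    bs, topk, num_steps = 2, 2, 3  # illustrative sizes, not from the patch
    out_cache_loc = torch.arange(bs * topk * num_steps)

    # Old transform: view(num_steps, -1).T interleaves cache slots from
    # different (batch, topk) drafts into each page-table row.
    old = out_cache_loc.view(num_steps, -1).T.contiguous()
    # -> [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]

    # New transform: view(-1, num_steps) keeps one draft's per-step cache
    # locations contiguous in each row.
    new = out_cache_loc.view(-1, num_steps)
    # -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]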