diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 97eead3af..6cbca78e9 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -406,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=device,
             )
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
             cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T.contiguous()
+                -1, self.speculative_num_steps
+            )
             metadata_expand.page_table = (
                 cache_loc[:, :decode_length].contiguous().to(torch.int32)
             )
@@ -1636,9 +1637,8 @@ class FlashAttentionBackend(AttentionBackend):
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
             metadata_expand = self.draft_decode_metadata_topk_expand[bs]
             decode_length = self.speculative_step_id + 1
-            cache_loc = out_cache_loc.view(
-                self.speculative_num_steps, -1
-            ).T.contiguous()
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+            cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
             metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                 cache_loc[:, :decode_length]
             )