[EAGLE] Fix draft kv cache layout for fa3 and topk > 1 (#7239)
This commit is contained in:
@@ -406,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
# shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
|
||||||
cache_loc = forward_batch.out_cache_loc.view(
|
cache_loc = forward_batch.out_cache_loc.view(
|
||||||
self.speculative_num_steps, -1
|
-1, self.speculative_num_steps
|
||||||
).T.contiguous()
|
)
|
||||||
metadata_expand.page_table = (
|
metadata_expand.page_table = (
|
||||||
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||||
)
|
)
|
||||||
@@ -1636,9 +1637,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
|
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
|
||||||
decode_length = self.speculative_step_id + 1
|
decode_length = self.speculative_step_id + 1
|
||||||
cache_loc = out_cache_loc.view(
|
# shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
|
||||||
self.speculative_num_steps, -1
|
cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
|
||||||
).T.contiguous()
|
|
||||||
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
||||||
cache_loc[:, :decode_length]
|
cache_loc[:, :decode_length]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user