Add Eagle Speculative Decoding to FA3 Backend (#4951)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com> Co-authored-by: zcnrex <zcnrex@gmail.com>
This commit is contained in:
@@ -45,6 +45,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self,
|
self,
|
||||||
model_runner: ModelRunner,
|
model_runner: ModelRunner,
|
||||||
skip_prefill: bool = False,
|
skip_prefill: bool = False,
|
||||||
|
topk=0,
|
||||||
|
speculative_num_steps=0,
|
||||||
|
step_id=0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -63,6 +66,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.use_mla = (
|
self.use_mla = (
|
||||||
model_runner.model_config.attention_arch == AttentionArch.MLA
|
model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
) and (not global_server_args_dict["disable_mla"])
|
) and (not global_server_args_dict["disable_mla"])
|
||||||
|
self.skip_prefill = skip_prefill
|
||||||
|
self.topk = topk
|
||||||
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
self.step_id = step_id
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize forward metadata to cache repetitive calculations."""
|
"""Initialize forward metadata to cache repetitive calculations."""
|
||||||
@@ -72,18 +79,133 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# Get sequence information
|
# Get sequence information
|
||||||
seqlens_in_batch = forward_batch.seq_lens
|
seqlens_in_batch = forward_batch.seq_lens
|
||||||
# Precompute int32 version of sequence lengths
|
# Precompute int32 version of sequence lengths
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
|
||||||
batch_size = len(seqlens_in_batch)
|
batch_size = len(seqlens_in_batch)
|
||||||
device = seqlens_in_batch.device
|
device = seqlens_in_batch.device
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
if forward_batch.forward_mode == ForwardMode.DECODE:
|
||||||
)
|
if self.skip_prefill:
|
||||||
# Precompute maximum sequence length
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
0, batch_size * self.topk + 1, dtype=torch.int32, device=device
|
||||||
# Precompute page table
|
)
|
||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
metadata.cache_seqlens_int32 = (
|
||||||
]
|
(seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
|
||||||
|
)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
|
||||||
|
self.step_id + 1
|
||||||
|
)
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
metadata.page_table = metadata.page_table.repeat_interleave(
|
||||||
|
self.topk, dim=0
|
||||||
|
)
|
||||||
|
cache_loc = forward_batch.out_cache_loc.view(
|
||||||
|
self.speculative_num_steps, -1
|
||||||
|
).T
|
||||||
|
# Calculate page table indices and cache location indices to update the page table.
|
||||||
|
batch_indices = torch.arange(
|
||||||
|
batch_size, device=device
|
||||||
|
).repeat_interleave(self.topk * (self.step_id + 1))
|
||||||
|
topk_indices = torch.arange(self.topk, device=device).repeat(
|
||||||
|
batch_size * (self.step_id + 1)
|
||||||
|
)
|
||||||
|
row_indices = batch_indices * self.topk + topk_indices
|
||||||
|
|
||||||
|
page_table_col_base_indices = seqlens_in_batch.unsqueeze(
|
||||||
|
1
|
||||||
|
) + torch.arange(self.step_id + 1, device=device)
|
||||||
|
page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
|
||||||
|
self.topk
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_loc_col_indices = torch.arange(
|
||||||
|
self.step_id + 1, device=device, dtype=torch.int32
|
||||||
|
).repeat(batch_size * self.topk)
|
||||||
|
|
||||||
|
metadata.page_table[row_indices, page_table_col_indices] = cache_loc[
|
||||||
|
row_indices, cache_loc_col_indices
|
||||||
|
].to(torch.int32)
|
||||||
|
else:
|
||||||
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
# Precompute page table
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
|
||||||
|
draft_token_num = forward_batch.spec_info.draft_token_num
|
||||||
|
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
|
||||||
|
metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(
|
||||||
|
forward_batch.spec_info.draft_token_num
|
||||||
|
)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.max_seq_len_k = (
|
||||||
|
forward_batch.seq_lens_cpu.max().item() + draft_token_num
|
||||||
|
)
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
].repeat_interleave(draft_token_num, dim=0)
|
||||||
|
aug_cum_len = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
for idx, single_seq_len in enumerate(aug_seq_lens):
|
||||||
|
metadata.page_table[
|
||||||
|
idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
|
||||||
|
] *= forward_batch.spec_info.custom_mask[
|
||||||
|
aug_cum_len[idx]
|
||||||
|
* draft_token_num : aug_cum_len[idx + 1]
|
||||||
|
* draft_token_num
|
||||||
|
].view(
|
||||||
|
draft_token_num, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata.max_seq_len_q = 1
|
||||||
|
else:
|
||||||
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
# Precompute page table
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
if (
|
||||||
|
any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
|
||||||
|
):
|
||||||
|
extend_seq_lens = forward_batch.extend_seq_lens
|
||||||
|
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||||
|
else:
|
||||||
|
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||||
|
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||||
|
|
||||||
# Precompute strided indices
|
# Precompute strided indices
|
||||||
# [0, page_size, 2 * page_size, ...]
|
# [0, page_size, 2 * page_size, ...]
|
||||||
@@ -94,23 +216,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = (
|
metadata.page_table = (
|
||||||
metadata.page_table[:, self.strided_indices] // self.page_size
|
metadata.page_table[:, self.strided_indices] // self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_batch.forward_mode == ForwardMode.DECODE:
|
|
||||||
# Precompute cumulative sequence lengths
|
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
|
||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Precompute cumulative sequence lengths
|
|
||||||
if any(forward_batch.extend_prefix_lens_cpu):
|
|
||||||
extend_seq_lens = forward_batch.extend_seq_lens
|
|
||||||
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
|
||||||
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
|
||||||
)
|
|
||||||
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
|
||||||
else:
|
|
||||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
|
||||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
@@ -281,8 +386,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Pre-reshape query tensor
|
# Pre-reshape query tensor
|
||||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
# Run attention with precomputed values
|
|
||||||
o = flash_attn_with_kvcache(
|
o = flash_attn_with_kvcache(
|
||||||
q=q_reshaped,
|
q=q_reshaped,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
@@ -346,7 +449,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
to avoid memory allocations.
|
to avoid memory allocations.
|
||||||
"""
|
"""
|
||||||
# Initialize fixed size tensors for decode operations
|
if self.speculative_num_steps > 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
|
||||||
|
)
|
||||||
|
|
||||||
self.decode_cuda_graph_metadata = {
|
self.decode_cuda_graph_metadata = {
|
||||||
# Page table for token mapping (batch_size, max_context_len)
|
# Page table for token mapping (batch_size, max_context_len)
|
||||||
"page_table": torch.zeros(
|
"page_table": torch.zeros(
|
||||||
@@ -385,7 +492,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||||
req_pool_indices, :
|
req_pool_indices, :
|
||||||
]
|
]
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode.is_cuda_graph():
|
||||||
# Precompute cumulative sequence lengths
|
# Precompute cumulative sequence lengths
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
@@ -432,3 +539,66 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
"""Get the fill value for sequence length in CUDA graph."""
|
"""Get the fill value for sequence length in CUDA graph."""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttentionMultiStepBackend:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
|
||||||
|
):
|
||||||
|
self.model_runner = model_runner
|
||||||
|
self.topk = topk
|
||||||
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
|
||||||
|
self.attn_backends = []
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends.append(
|
||||||
|
FlashAttentionBackend(
|
||||||
|
model_runner,
|
||||||
|
skip_prefill=True,
|
||||||
|
topk=self.topk,
|
||||||
|
speculative_num_steps=self.speculative_num_steps,
|
||||||
|
step_id=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
for i in range(self.speculative_num_steps - 1):
|
||||||
|
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends[i].init_cuda_graph_state(max_bs)
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||||
|
assert forward_batch.spec_info is not None
|
||||||
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
|
|
||||||
|
for i in range(self.speculative_num_steps - 1):
|
||||||
|
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
forward_batch.batch_size * self.topk,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
|
):
|
||||||
|
assert forward_batch.spec_info is not None
|
||||||
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
|
|
||||||
|
for i in range(self.speculative_num_steps - 1):
|
||||||
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
|
bs,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
|
)
|
||||||
|
|||||||
@@ -184,6 +184,19 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.draft_extend_attn_backend = None
|
self.draft_extend_attn_backend = None
|
||||||
self.padded_static_len = self.speculative_num_steps + 1
|
self.padded_static_len = self.speculative_num_steps + 1
|
||||||
self.has_prefill_wrapper_verify = True
|
self.has_prefill_wrapper_verify = True
|
||||||
|
elif self.server_args.attention_backend == "fa3":
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
|
FlashAttentionMultiStepBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_attn_backend = FlashAttentionMultiStepBackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
|
self.draft_extend_attn_backend = None
|
||||||
|
self.padded_static_len = self.speculative_num_steps + 1
|
||||||
|
self.has_prefill_wrapper_verify = False
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
|
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
|
||||||
|
|||||||
Reference in New Issue
Block a user