FA3 Spec Decoding to support top k = 1 and add cuda graph support (#5050)
Co-authored-by: Qingquan Song <ustcsqq@gmail.com> Co-authored-by: Chunan Zeng <zcnrex@gmail.com>
This commit is contained in:
@@ -27,19 +27,42 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashAttentionMetadata:
|
class FlashAttentionMetadata:
|
||||||
"""Metadata for decode operations to avoid redundant computations."""
|
"""Metadata to be init once in the model forward pass,
|
||||||
|
each layer's forward pass can reuse the metadata."""
|
||||||
|
|
||||||
|
# Cumulative sequence lengths for query
|
||||||
cu_seqlens_q: torch.Tensor = None
|
cu_seqlens_q: torch.Tensor = None
|
||||||
|
# Cumulative sequence lengths for key
|
||||||
cu_seqlens_k: torch.Tensor = None
|
cu_seqlens_k: torch.Tensor = None
|
||||||
|
# Maximum sequence length for query
|
||||||
max_seq_len_q: int = 0
|
max_seq_len_q: int = 0
|
||||||
|
# Maximum sequence length for key
|
||||||
max_seq_len_k: int = 0
|
max_seq_len_k: int = 0
|
||||||
|
# Window size (typically used by Gemma)
|
||||||
window_size: tuple = (-1, -1)
|
window_size: tuple = (-1, -1)
|
||||||
|
# Page table, the index of KV Cache Tables/Blocks
|
||||||
page_table: torch.Tensor = None
|
page_table: torch.Tensor = None
|
||||||
|
# Sequence lengths for the forward batch
|
||||||
cache_seqlens_int32: torch.Tensor = None
|
cache_seqlens_int32: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
"""FlashAttention backend implementation."""
|
"""FlashAttention backend implementation.
|
||||||
|
|
||||||
|
Note about the init:
|
||||||
|
- If no spec decoding
|
||||||
|
- FlashAttentionBackend will be init once when the server starts.
|
||||||
|
- If spec decoding
|
||||||
|
- FlashAttentionBackend will be init once for the target worker
|
||||||
|
- FlashAttentionMultiStepBackend will be once for the draft worker
|
||||||
|
- It will spawn num_steps FlashAttentionBackend for the draft worker
|
||||||
|
|
||||||
|
Note about CUDA Graph:
|
||||||
|
- We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
|
||||||
|
- We don't support CUDA Graph for Extend and Draft Extend.
|
||||||
|
- When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
|
||||||
|
- For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -56,41 +79,42 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
and model_runner.model_config.is_encoder_decoder
|
and model_runner.model_config.is_encoder_decoder
|
||||||
), "Sliding window and cross attention are not supported together"
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
# Initialize metadata
|
|
||||||
self.forward_metadata: FlashAttentionMetadata = None
|
self.forward_metadata: FlashAttentionMetadata = None
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
|
self.target_verify_metadata = {}
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
self.use_mla = (
|
self.use_mla = (
|
||||||
model_runner.model_config.attention_arch == AttentionArch.MLA
|
model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
) and (not global_server_args_dict["disable_mla"])
|
) and (not global_server_args_dict["disable_mla"])
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
self.topk = topk
|
|
||||||
self.speculative_num_steps = speculative_num_steps
|
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
|
||||||
|
assert (
|
||||||
|
topk <= 1
|
||||||
|
), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
|
||||||
|
|
||||||
|
self.topk = 1
|
||||||
self.step_id = step_id
|
self.step_id = step_id
|
||||||
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize forward metadata to cache repetitive calculations."""
|
"""Initialize forward metadata to cache repetitive calculations."""
|
||||||
# Create metadata based on forward mode
|
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
|
|
||||||
# Get sequence information
|
|
||||||
seqlens_in_batch = forward_batch.seq_lens
|
seqlens_in_batch = forward_batch.seq_lens
|
||||||
# Precompute int32 version of sequence lengths
|
|
||||||
batch_size = len(seqlens_in_batch)
|
batch_size = len(seqlens_in_batch)
|
||||||
device = seqlens_in_batch.device
|
device = seqlens_in_batch.device
|
||||||
|
if forward_batch.forward_mode.is_decode():
|
||||||
if forward_batch.forward_mode == ForwardMode.DECODE:
|
# Skip Prefill or Draft Decode
|
||||||
if self.skip_prefill:
|
# Note: Draft Decode will be ran on the Draft Worker
|
||||||
|
if forward_batch.spec_info is not None:
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
0, batch_size * self.topk + 1, dtype=torch.int32, device=device
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
|
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
|
||||||
metadata.cache_seqlens_int32 = (
|
metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
|
||||||
(seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
|
|
||||||
)
|
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(
|
torch.cumsum(
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
@@ -103,86 +127,58 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
metadata.page_table = metadata.page_table.repeat_interleave(
|
|
||||||
self.topk, dim=0
|
|
||||||
)
|
|
||||||
cache_loc = forward_batch.out_cache_loc.view(
|
cache_loc = forward_batch.out_cache_loc.view(
|
||||||
self.speculative_num_steps, -1
|
self.speculative_num_steps, -1
|
||||||
).T
|
).T
|
||||||
# Calculate page table indices and cache location indices to update the page table.
|
|
||||||
batch_indices = torch.arange(
|
|
||||||
batch_size, device=device
|
|
||||||
).repeat_interleave(self.topk * (self.step_id + 1))
|
|
||||||
topk_indices = torch.arange(self.topk, device=device).repeat(
|
|
||||||
batch_size * (self.step_id + 1)
|
|
||||||
)
|
|
||||||
row_indices = batch_indices * self.topk + topk_indices
|
|
||||||
|
|
||||||
page_table_col_base_indices = seqlens_in_batch.unsqueeze(
|
for idx, single_seq_len in enumerate(seq_lens_with_decode):
|
||||||
1
|
real_bsz_start_idx = idx
|
||||||
) + torch.arange(self.step_id + 1, device=device)
|
real_bsz_end_idx = idx + 1
|
||||||
page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
|
metadata.page_table[
|
||||||
self.topk
|
real_bsz_start_idx:real_bsz_end_idx,
|
||||||
)
|
(single_seq_len - (self.step_id + 1)) : single_seq_len,
|
||||||
|
] = cache_loc[
|
||||||
cache_loc_col_indices = torch.arange(
|
real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
|
||||||
self.step_id + 1, device=device, dtype=torch.int32
|
]
|
||||||
).repeat(batch_size * self.topk)
|
else: # Normal Decode without Spec Decoding
|
||||||
|
|
||||||
metadata.page_table[row_indices, page_table_col_indices] = cache_loc[
|
|
||||||
row_indices, cache_loc_col_indices
|
|
||||||
].to(torch.int32)
|
|
||||||
else:
|
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
)
|
)
|
||||||
# Precompute maximum sequence length
|
|
||||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
# Precompute page table
|
|
||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
# Note: Target Verify will be ran on the Target Worker
|
||||||
draft_token_num = forward_batch.spec_info.draft_token_num
|
draft_token_num = forward_batch.spec_info.draft_token_num
|
||||||
|
metadata.cache_seqlens_int32 = (
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
forward_batch.seq_lens + draft_token_num
|
||||||
0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device
|
).to(torch.int32)
|
||||||
|
metadata.max_seq_len_q = draft_token_num
|
||||||
|
metadata.max_seq_len_k = (
|
||||||
|
forward_batch.seq_lens_cpu.max().item() + draft_token_num
|
||||||
)
|
)
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
|
0,
|
||||||
metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(
|
batch_size * draft_token_num + 1,
|
||||||
forward_batch.spec_info.draft_token_num
|
draft_token_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
||||||
(1, 0),
|
(1, 0),
|
||||||
)
|
)
|
||||||
metadata.max_seq_len_k = (
|
|
||||||
forward_batch.seq_lens_cpu.max().item() + draft_token_num
|
|
||||||
)
|
|
||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
].repeat_interleave(draft_token_num, dim=0)
|
]
|
||||||
aug_cum_len = torch.nn.functional.pad(
|
|
||||||
torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
|
||||||
)
|
|
||||||
for idx, single_seq_len in enumerate(aug_seq_lens):
|
|
||||||
metadata.page_table[
|
|
||||||
idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
|
|
||||||
] *= forward_batch.spec_info.custom_mask[
|
|
||||||
aug_cum_len[idx]
|
|
||||||
* draft_token_num : aug_cum_len[idx + 1]
|
|
||||||
* draft_token_num
|
|
||||||
].view(
|
|
||||||
draft_token_num, -1
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata.max_seq_len_q = 1
|
elif forward_batch.forward_mode.is_extend_or_draft_extend():
|
||||||
else:
|
# Normal or Draft Extend (Both of them will be ran on the Target Worker)
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
@@ -208,7 +204,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||||
|
|
||||||
# Precompute strided indices
|
# Precompute strided indices
|
||||||
# [0, page_size, 2 * page_size, ...]
|
|
||||||
if self.page_size > 1:
|
if self.page_size > 1:
|
||||||
self.strided_indices = torch.arange(
|
self.strided_indices = torch.arange(
|
||||||
0, metadata.page_table.shape[1], self.page_size, device=self.device
|
0, metadata.page_table.shape[1], self.page_size, device=self.device
|
||||||
@@ -227,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
@@ -262,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
page_table = metadata.page_table
|
page_table = metadata.page_table
|
||||||
|
|
||||||
# # Use Flash Attention for prefill
|
# Use Flash Attention for prefill
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
@@ -368,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
page_table = metadata.page_table
|
page_table = metadata.page_table
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
@@ -437,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale=layer.k_scale,
|
k_descale=layer.k_scale,
|
||||||
v_descale=layer.v_scale,
|
v_descale=layer.v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
@@ -449,11 +441,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
to avoid memory allocations.
|
to avoid memory allocations.
|
||||||
"""
|
"""
|
||||||
if self.speculative_num_steps > 0:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.decode_cuda_graph_metadata = {
|
self.decode_cuda_graph_metadata = {
|
||||||
# Page table for token mapping (batch_size, max_context_len)
|
# Page table for token mapping (batch_size, max_context_len)
|
||||||
"page_table": torch.zeros(
|
"page_table": torch.zeros(
|
||||||
@@ -462,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
|
"page_table_draft_decode": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"strided_indices": torch.arange(
|
||||||
|
0, self.max_context_len, self.page_size, device=self.device
|
||||||
|
),
|
||||||
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0, max_bs + 128, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs + 128, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
self.target_verify_metadata = {
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
|
"cu_seqlens_q": torch.zeros(
|
||||||
|
max_bs + 128, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs + 128, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"max_seqlen_q": 0,
|
||||||
"strided_indices": torch.arange(
|
"strided_indices": torch.arange(
|
||||||
0, self.max_context_len, self.page_size, device=self.device
|
0, self.max_context_len, self.page_size, device=self.device
|
||||||
),
|
),
|
||||||
@@ -479,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
"""Initialize forward metadata for capturing CUDA graph."""
|
"""Initialize forward metadata for capturing CUDA graph."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
# Get sequence information
|
|
||||||
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
device = seq_lens.device
|
device = seq_lens.device
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
if forward_mode.is_decode():
|
||||||
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
if spec_info is not None:
|
||||||
)
|
# Draft Decode
|
||||||
# Precompute maximum sequence length
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
metadata.max_seq_len_k = seq_lens.max().item()
|
0, bs + 1, dtype=torch.int32, device=device
|
||||||
# Precompute page table
|
)
|
||||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||||
req_pool_indices, :
|
"cache_seqlens"
|
||||||
]
|
][:bs]
|
||||||
if forward_mode.is_cuda_graph():
|
|
||||||
# Precompute cumulative sequence lengths
|
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
|
||||||
metadata.cu_seqlens_q = torch.arange(
|
: bs + 1
|
||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
]
|
||||||
|
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
|
||||||
|
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||||
|
"page_table_draft_decode"
|
||||||
|
][req_pool_indices, :]
|
||||||
|
else:
|
||||||
|
# Normal Decode
|
||||||
|
# Get sequence information
|
||||||
|
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
device = seq_lens.device
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = seq_lens.max().item()
|
||||||
|
# Precompute page table
|
||||||
|
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||||
|
req_pool_indices, :
|
||||||
|
]
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
elif forward_mode.is_target_verify():
|
||||||
|
draft_token_num = spec_info.draft_token_num
|
||||||
|
|
||||||
|
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
|
||||||
|
:bs
|
||||||
|
]
|
||||||
|
metadata.cache_seqlens_int32.copy_(
|
||||||
|
(seq_lens + draft_token_num).to(torch.int32)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise ValueError("Do not support Prefill Mode cuda graph")
|
metadata.max_seq_len_q = draft_token_num
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
|
||||||
|
|
||||||
|
metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
|
||||||
|
torch.arange(
|
||||||
|
0,
|
||||||
|
bs * draft_token_num + 1,
|
||||||
|
draft_token_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
|
||||||
|
cu_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metadata.cu_seqlens_k = cu_k
|
||||||
|
metadata.page_table = self.target_verify_metadata["page_table"][
|
||||||
|
req_pool_indices, :
|
||||||
|
]
|
||||||
|
|
||||||
|
self.target_verify_metadata[bs] = metadata
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -512,28 +594,91 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
|
out_cache_loc: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
# """Initialize forward metadata for replaying CUDA graph."""
|
# """Initialize forward metadata for replaying CUDA graph."""
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
device = seq_lens.device
|
||||||
|
seq_lens = seq_lens[:bs]
|
||||||
|
req_pool_indices = req_pool_indices[:bs]
|
||||||
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
|
if forward_mode.is_decode():
|
||||||
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
# For CPU operations
|
if spec_info is not None:
|
||||||
max_len = seq_lens_cpu[:bs].max().item()
|
# Draft Decode
|
||||||
metadata.max_seq_len_k = max_len
|
max_len = seq_lens_cpu.max().item()
|
||||||
|
metadata.max_seq_len_k = max_len + (self.step_id + 1)
|
||||||
|
|
||||||
# For GPU operations
|
metadata.cache_seqlens_int32.copy_(
|
||||||
seq_lens_in_batch = seq_lens[:bs]
|
(seq_lens + (self.step_id + 1)).to(torch.int32)
|
||||||
metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
|
)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
|
||||||
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
|
||||||
)
|
|
||||||
|
metadata.cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
page_table = self.req_to_token[
|
||||||
|
req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
|
||||||
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
else:
|
||||||
|
# Normal Decode
|
||||||
|
max_len = seq_lens_cpu.max().item()
|
||||||
|
metadata.max_seq_len_k = max_len
|
||||||
|
|
||||||
|
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seq_pages = (
|
||||||
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
|
) // self.page_size
|
||||||
|
page_indices = self.req_to_token[
|
||||||
|
:,
|
||||||
|
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
||||||
|
]
|
||||||
|
page_indices = page_indices[req_pool_indices] // self.page_size
|
||||||
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
|
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||||
|
|
||||||
|
elif forward_mode.is_target_verify():
|
||||||
|
metadata = self.target_verify_metadata[bs]
|
||||||
|
draft_token_num = spec_info.draft_token_num
|
||||||
|
|
||||||
|
metadata.cu_seqlens_q.copy_(
|
||||||
|
torch.arange(
|
||||||
|
0,
|
||||||
|
bs * draft_token_num + 1,
|
||||||
|
draft_token_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metadata.cache_seqlens_int32.copy_(
|
||||||
|
(seq_lens + draft_token_num).to(torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
|
||||||
|
metadata.cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
|
||||||
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
|
|
||||||
page_indices = self.req_to_token[
|
|
||||||
:, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
|
|
||||||
]
|
|
||||||
page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
|
|
||||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
|
||||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
@@ -555,7 +700,6 @@ class FlashAttentionMultiStepBackend:
|
|||||||
self.attn_backends.append(
|
self.attn_backends.append(
|
||||||
FlashAttentionBackend(
|
FlashAttentionBackend(
|
||||||
model_runner,
|
model_runner,
|
||||||
skip_prefill=True,
|
|
||||||
topk=self.topk,
|
topk=self.topk,
|
||||||
speculative_num_steps=self.speculative_num_steps,
|
speculative_num_steps=self.speculative_num_steps,
|
||||||
step_id=i,
|
step_id=i,
|
||||||
@@ -570,7 +714,10 @@ class FlashAttentionMultiStepBackend:
|
|||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends[i].init_cuda_graph_state(max_bs)
|
self.attn_backends[i].init_cuda_graph_state(max_bs)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
assert forward_batch.spec_info is not None
|
assert forward_batch.spec_info is not None
|
||||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
|
|
||||||
@@ -601,4 +748,5 @@ class FlashAttentionMultiStepBackend:
|
|||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
|
out_cache_loc=forward_batch.out_cache_loc,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
|
|||||||
or self == ForwardMode.IDLE
|
or self == ForwardMode.IDLE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_extend_or_draft_extend(self):
|
||||||
|
return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
|
||||||
|
|
||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user