Refactor and Optimize FA3 Code (#5090)
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
This commit is contained in:
@@ -1,24 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
"""
|
||||
Support different attention backends.
|
||||
Now there are three backends: FlashInfer, Triton and FlashAttention.
|
||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -30,22 +22,25 @@ from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
"""Metadata to be init once in the model forward pass,
|
||||
each layer's forward pass can reuse the metadata."""
|
||||
each layer's forward pass can reuse the metadata.
|
||||
|
||||
# Cumulative sequence lengths for query
|
||||
cu_seqlens_q: torch.Tensor = None
|
||||
# Cumulative sequence lengths for key
|
||||
cu_seqlens_k: torch.Tensor = None
|
||||
For each init metadata function, we will try set up them in below order
|
||||
"""
|
||||
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
# Maximum sequence length for query
|
||||
max_seq_len_q: int = 0
|
||||
# Maximum sequence length for key
|
||||
max_seq_len_k: int = 0
|
||||
# Cumulative sequence lengths for query
|
||||
cu_seqlens_q: torch.Tensor = None
|
||||
# Cumulative sequence lengths for key
|
||||
cu_seqlens_k: torch.Tensor = None
|
||||
# Window size (typically used by Gemma)
|
||||
window_size: tuple = (-1, -1)
|
||||
# Page table, the index of KV Cache Tables/Blocks
|
||||
page_table: torch.Tensor = None
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
@@ -270,9 +265,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
speculative_step_id=0,
|
||||
topk=0,
|
||||
speculative_num_steps=0,
|
||||
step_id=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -293,14 +288,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
) and (not global_server_args_dict["disable_mla"])
|
||||
self.skip_prefill = skip_prefill
|
||||
|
||||
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
|
||||
assert (
|
||||
topk <= 1
|
||||
), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
|
||||
|
||||
self.topk = 1
|
||||
self.step_id = step_id
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.speculative_num_draft_tokens = (
|
||||
model_runner.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
self.speculative_step_id = speculative_step_id
|
||||
|
||||
# Local attention settings
|
||||
self.attention_chunk_size = (
|
||||
@@ -310,71 +303,59 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Initialize forward metadata to cache repetitive calculations."""
|
||||
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
seqlens_in_batch = forward_batch.seq_lens
|
||||
batch_size = len(seqlens_in_batch)
|
||||
device = seqlens_in_batch.device
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
# Skip Prefill or Draft Decode
|
||||
# Note: Draft Decode will be ran on the Draft Worker
|
||||
# Draft Decode
|
||||
if forward_batch.spec_info is not None:
|
||||
metadata.cache_seqlens_int32 = (
|
||||
seqlens_in_batch + (self.speculative_step_id + 1)
|
||||
).to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
|
||||
metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
|
||||
self.step_id + 1
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
cache_loc = forward_batch.out_cache_loc.view(
|
||||
self.speculative_num_steps, -1
|
||||
).T
|
||||
|
||||
for idx, single_seq_len in enumerate(seq_lens_with_decode):
|
||||
real_bsz_start_idx = idx
|
||||
real_bsz_end_idx = idx + 1
|
||||
metadata.page_table[
|
||||
real_bsz_start_idx:real_bsz_end_idx,
|
||||
(single_seq_len - (self.step_id + 1)) : single_seq_len,
|
||||
] = cache_loc[
|
||||
real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
|
||||
]
|
||||
else: # Normal Decode without Spec Decoding
|
||||
else:
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
# Note: Target Verify will be ran on the Target Worker
|
||||
draft_token_num = forward_batch.spec_info.draft_token_num
|
||||
metadata.cache_seqlens_int32 = (
|
||||
forward_batch.seq_lens + draft_token_num
|
||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||
).to(torch.int32)
|
||||
metadata.max_seq_len_q = draft_token_num
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
metadata.max_seq_len_k = (
|
||||
forward_batch.seq_lens_cpu.max().item() + draft_token_num
|
||||
forward_batch.seq_lens_cpu.max().item()
|
||||
+ self.speculative_num_draft_tokens
|
||||
)
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size * draft_token_num + 1,
|
||||
draft_token_num,
|
||||
batch_size * self.speculative_num_draft_tokens + 1,
|
||||
self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
@@ -387,31 +368,27 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
]
|
||||
|
||||
elif forward_batch.forward_mode.is_extend_or_draft_extend():
|
||||
# Normal or Draft Extend (Both of them will be ran on the Target Worker)
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
# Precompute page table
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
# Precompute cumulative sequence lengths
|
||||
if (
|
||||
any(forward_batch.extend_prefix_lens_cpu)
|
||||
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
|
||||
):
|
||||
extend_seq_lens = forward_batch.extend_seq_lens
|
||||
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||
else:
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
|
||||
# Setup local attention if enabled
|
||||
if (
|
||||
@@ -458,7 +435,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
metadata.local_attn_metadata = local_metadata
|
||||
|
||||
# Precompute strided indices
|
||||
# Convert the page table to a strided format which is needed by FA3 API
|
||||
if self.page_size > 1:
|
||||
self.strided_indices = torch.arange(
|
||||
0, metadata.page_table.shape[1], self.page_size, device=self.device
|
||||
@@ -498,7 +475,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
v,
|
||||
)
|
||||
|
||||
# Use precomputed metadata
|
||||
# Use precomputed metadata across all layers
|
||||
metadata = self.forward_metadata
|
||||
|
||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||
@@ -606,8 +583,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention using precomputed metadata."""
|
||||
# Save KV cache if needed
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
@@ -628,7 +603,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
v,
|
||||
)
|
||||
|
||||
# Use precomputed metadata
|
||||
# Use precomputed metadata across all layers
|
||||
metadata = self.forward_metadata
|
||||
|
||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||
@@ -639,12 +614,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
if layer.sliding_window_size is not None
|
||||
else (-1, -1)
|
||||
)
|
||||
page_table = metadata.page_table
|
||||
|
||||
if not self.use_mla:
|
||||
# Do multi-head attention
|
||||
|
||||
# Get KV cache
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
key_cache = key_cache.view(
|
||||
@@ -654,13 +626,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||
)
|
||||
|
||||
# Pre-reshape query tensor
|
||||
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q_reshaped,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=page_table,
|
||||
page_table=metadata.page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
@@ -696,7 +667,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_cache=k_rope_cache,
|
||||
v_cache=c_kv_cache,
|
||||
qv=q_nope,
|
||||
page_table=page_table,
|
||||
page_table=metadata.page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
@@ -719,7 +690,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
to avoid memory allocations.
|
||||
"""
|
||||
self.decode_cuda_graph_metadata = {
|
||||
# Page table for token mapping (batch_size, max_context_len)
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
@@ -735,30 +712,22 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0, max_bs + 128, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 128, dtype=torch.int32, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
self.target_verify_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.zeros(
|
||||
max_bs + 128, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 128, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"max_seqlen_q": 0,
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
@@ -780,24 +749,21 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
if forward_mode.is_decode():
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, bs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||
"cache_seqlens"
|
||||
][:bs]
|
||||
|
||||
metadata.max_seq_len_k = seq_lens.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
|
||||
: bs + 1
|
||||
]
|
||||
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
|
||||
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||
"page_table_draft_decode"
|
||||
][req_pool_indices, :]
|
||||
@@ -822,37 +788,30 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
elif forward_mode.is_target_verify():
|
||||
draft_token_num = spec_info.draft_token_num
|
||||
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
|
||||
:bs
|
||||
]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + draft_token_num).to(torch.int32)
|
||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||
)
|
||||
|
||||
metadata.max_seq_len_q = draft_token_num
|
||||
metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
metadata.max_seq_len_k = (
|
||||
seq_lens.max().item() + self.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
|
||||
torch.arange(
|
||||
0,
|
||||
bs * draft_token_num + 1,
|
||||
draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * self.speculative_num_draft_tokens + 1,
|
||||
self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
|
||||
: (bs + 1)
|
||||
]
|
||||
cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
|
||||
cu_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
)
|
||||
metadata.cu_seqlens_k = cu_k
|
||||
|
||||
metadata.page_table = self.target_verify_metadata["page_table"][
|
||||
req_pool_indices, :
|
||||
]
|
||||
@@ -874,24 +833,21 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
out_cache_loc: torch.Tensor = None,
|
||||
):
|
||||
# """Initialize forward metadata for replaying CUDA graph."""
|
||||
device = seq_lens.device
|
||||
seq_lens = seq_lens[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
if forward_mode.is_decode():
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
metadata.max_seq_len_k = max_len + (self.step_id + 1)
|
||||
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + (self.step_id + 1)).to(torch.int32)
|
||||
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
|
||||
)
|
||||
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
|
||||
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
@@ -929,22 +885,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
elif forward_mode.is_target_verify():
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
draft_token_num = spec_info.draft_token_num
|
||||
|
||||
metadata.cu_seqlens_q.copy_(
|
||||
torch.arange(
|
||||
0,
|
||||
bs * draft_token_num + 1,
|
||||
draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + draft_token_num).to(torch.int32)
|
||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||
)
|
||||
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
|
||||
metadata.max_seq_len_k = (
|
||||
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
|
||||
)
|
||||
metadata.cu_seqlens_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
@@ -972,14 +919,19 @@ class FlashAttentionMultiStepBackend:
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
|
||||
assert (
|
||||
self.topk == 1
|
||||
), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
|
||||
|
||||
self.attn_backends = []
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends.append(
|
||||
FlashAttentionBackend(
|
||||
model_runner,
|
||||
speculative_step_id=i,
|
||||
topk=self.topk,
|
||||
speculative_num_steps=self.speculative_num_steps,
|
||||
step_id=i,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user