Support FA3 as Attention backend by using --attention-backend fa3 (#4680)
Co-authored-by: qsong <qsong@linkedin.com> Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
@@ -501,6 +501,7 @@ def get_dataset(args, tokenizer):
|
|||||||
question_len=args.gsp_question_len,
|
question_len=args.gsp_question_len,
|
||||||
output_len=args.gsp_output_len,
|
output_len=args.gsp_output_len,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
args=args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||||
@@ -788,6 +789,7 @@ def sample_generated_shared_prefix_requests(
|
|||||||
question_len: int,
|
question_len: int,
|
||||||
output_len: int,
|
output_len: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
args: argparse.Namespace,
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
|
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
|
||||||
cache_path = get_gen_prefix_cache_path(args, tokenizer)
|
cache_path = get_gen_prefix_cache_path(args, tokenizer)
|
||||||
|
|||||||
295
python/sglang/srt/layers/attention/flashattention_backend.py
Normal file
295
python/sglang/srt/layers/attention/flashattention_backend.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
|
"""
|
||||||
|
Support different attention backends.
|
||||||
|
Now there are three backends: FlashInfer, Triton and FlashAttention.
|
||||||
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
from flash_attn_interface import flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlashAttentionMetadata:
|
||||||
|
"""Metadata for decode operations to avoid redundant computations."""
|
||||||
|
|
||||||
|
cu_seqlens_q: torch.Tensor = None
|
||||||
|
cu_seqlens_k: torch.Tensor = None
|
||||||
|
max_seq_len_k: int = 0
|
||||||
|
window_size: tuple = (-1, -1)
|
||||||
|
page_table: torch.Tensor = None
|
||||||
|
cache_seqlens_int32: torch.Tensor = None
|
||||||
|
max_seq_len_q: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
|
"""FlashAttention backend implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner: ModelRunner,
|
||||||
|
skip_prefill: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert not (
|
||||||
|
model_runner.sliding_window_size is not None
|
||||||
|
and model_runner.model_config.is_encoder_decoder
|
||||||
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
|
# Initialize metadata
|
||||||
|
self.forward_metadata: FlashAttentionMetadata = None
|
||||||
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
self.device = model_runner.device
|
||||||
|
self.decode_cuda_graph_metadata = {}
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Initialize forward metadata to cache repetitive calculations."""
|
||||||
|
# Create metadata based on forward mode
|
||||||
|
metadata = FlashAttentionMetadata()
|
||||||
|
|
||||||
|
extend_seq_lens = forward_batch.extend_seq_lens
|
||||||
|
# Get sequence information
|
||||||
|
seqlens_in_batch = forward_batch.seq_lens
|
||||||
|
# Precompute int32 version of sequence lengths
|
||||||
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
|
batch_size = len(seqlens_in_batch)
|
||||||
|
device = seqlens_in_batch.device
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = seqlens_in_batch.max().item()
|
||||||
|
# Precompute page table
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
if forward_batch.forward_mode == ForwardMode.DECODE:
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens)
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
if not extend_no_prefix:
|
||||||
|
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||||
|
metadata.max_seq_len_q = seqlens_in_batch.max().item()
|
||||||
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
):
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use precomputed metadata
|
||||||
|
metadata = self.forward_metadata
|
||||||
|
|
||||||
|
# # Use Flash Attention for prefill
|
||||||
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
|
# here is two side inclusive
|
||||||
|
window_size = (
|
||||||
|
(layer.sliding_window_size, 0)
|
||||||
|
if layer.sliding_window_size is not None
|
||||||
|
else (-1, -1)
|
||||||
|
)
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
|
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k_cache=key_cache.unsqueeze(1),
|
||||||
|
v_cache=value_cache.unsqueeze(1),
|
||||||
|
page_table=metadata.page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=metadata.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with FlashAttention using precomputed metadata."""
|
||||||
|
# Save KV cache if needed
|
||||||
|
if k is not None and v is not None and save_kv_cache:
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get KV cache
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
|
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
|
||||||
|
# Use precomputed metadata
|
||||||
|
metadata = self.forward_metadata
|
||||||
|
|
||||||
|
# Pre-reshape query tensor
|
||||||
|
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
|
# here is two side inclusive
|
||||||
|
window_size = (
|
||||||
|
(layer.sliding_window_size, 0)
|
||||||
|
if layer.sliding_window_size is not None
|
||||||
|
else (-1, -1)
|
||||||
|
)
|
||||||
|
# Run attention with precomputed values
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q_reshaped,
|
||||||
|
k_cache=key_cache.unsqueeze(1),
|
||||||
|
v_cache=value_cache.unsqueeze(1),
|
||||||
|
page_table=metadata.page_table,
|
||||||
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=layer.k_scale,
|
||||||
|
v_descale=layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
|
"""Initialize CUDA graph state for the attention backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_bs (int): Maximum batch size to support in CUDA graphs
|
||||||
|
|
||||||
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
|
to avoid memory allocations.
|
||||||
|
"""
|
||||||
|
# Initialize fixed size tensors for decode operations
|
||||||
|
self.decode_cuda_graph_metadata = {
|
||||||
|
# Page table for token mapping (batch_size, max_context_len)
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs, self.max_context_len, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
num_tokens: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
):
|
||||||
|
"""Initialize forward metadata for capturing CUDA graph."""
|
||||||
|
metadata = FlashAttentionMetadata()
|
||||||
|
# Get sequence information
|
||||||
|
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
device = seq_lens.device
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = seq_lens.max().item()
|
||||||
|
# Precompute page table
|
||||||
|
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||||
|
req_pool_indices, :
|
||||||
|
]
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Do not support Prefill Mode cuda graph")
|
||||||
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
# """Initialize forward metadata for replaying CUDA graph."""
|
||||||
|
seqlens_in_batch = seq_lens[:bs]
|
||||||
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
# Precompute maximum sequence length
|
||||||
|
metadata.max_seq_len_k = seqlens_in_batch.max().item()
|
||||||
|
# Only zero out the part out of max_len_k
|
||||||
|
metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
|
||||||
|
# Then do the copy
|
||||||
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(
|
||||||
|
self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
|
||||||
|
)
|
||||||
|
self.forward_decode_metadata = metadata
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
"""Get the fill value for sequence length in CUDA graph."""
|
||||||
|
return 0
|
||||||
@@ -868,6 +868,19 @@ class ModelRunner:
|
|||||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||||
|
|
||||||
self.attn_backend = FlashMLABackend(self)
|
self.attn_backend = FlashMLABackend(self)
|
||||||
|
elif self.server_args.attention_backend == "fa3":
|
||||||
|
assert torch.cuda.get_device_capability()[0] >= 9, (
|
||||||
|
"FlashAttention v3 Backend requires SM>=90. "
|
||||||
|
"Please use `--attention-backend flashinfer`."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
|
FlashAttentionBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn_backend = FlashAttentionBackend(self)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||||
|
|||||||
@@ -770,7 +770,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-backend",
|
"--attention-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["flashinfer", "triton", "torch_native"],
|
choices=["flashinfer", "triton", "torch_native", "fa3"],
|
||||||
default=ServerArgs.attention_backend,
|
default=ServerArgs.attention_backend,
|
||||||
help="Choose the kernels for attention layers.",
|
help="Choose the kernels for attention layers.",
|
||||||
)
|
)
|
||||||
|
|||||||
311
python/sglang/test/attention/test_flashattn_backend.py
Normal file
311
python/sglang/test/attention/test_flashattn_backend.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
|
||||||
|
class MockModelRunner:
|
||||||
|
model_config = type(
|
||||||
|
"ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
|
||||||
|
)
|
||||||
|
sliding_window_size = None
|
||||||
|
|
||||||
|
def __init__(self, device="cuda"):
|
||||||
|
self.device = device
|
||||||
|
# Create a proper req_to_token_pool with the req_to_token attribute
|
||||||
|
self.req_to_token_pool = type(
|
||||||
|
"TokenPool",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"size": 160, # a typical max_bs * max_context_len for cuda graph decode
|
||||||
|
"req_to_token": torch.zeros(
|
||||||
|
160, 2048, dtype=torch.int32, device=device
|
||||||
|
), # Add req_to_token attribute
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockReqToTokenPool:
|
||||||
|
def __init__(self, batch_size, seq_len, device):
|
||||||
|
self.req_to_token = (
|
||||||
|
torch.arange(batch_size * seq_len, device=device)
|
||||||
|
.reshape(batch_size, seq_len)
|
||||||
|
.to(torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
||||||
|
class TestFlashAttentionBackend(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures before each test method."""
|
||||||
|
self.model_runner = MockModelRunner()
|
||||||
|
self.backend = FlashAttentionBackend(self.model_runner)
|
||||||
|
|
||||||
|
# Common test parameters
|
||||||
|
self.batch_size = 2
|
||||||
|
self.seq_len = 4
|
||||||
|
self.num_heads = 2
|
||||||
|
self.head_dim = 8
|
||||||
|
self.device = "cuda"
|
||||||
|
self.dtype = torch.float16
|
||||||
|
|
||||||
|
def _create_attention_layer(self):
|
||||||
|
"""Helper method to create an attention layer."""
|
||||||
|
return RadixAttention(
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
scaling=1.0,
|
||||||
|
num_kv_heads=self.num_heads,
|
||||||
|
layer_id=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_kv_pool(self, size):
|
||||||
|
"""Helper method to create a KV pool."""
|
||||||
|
return MHATokenToKVPool(
|
||||||
|
size=size,
|
||||||
|
page_size=1, # only consider page=1 for unit test
|
||||||
|
dtype=self.dtype,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
layer_num=1, # only consider layer=1 for unit test
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_qkv_tensors(self, tokens_len):
|
||||||
|
"""Helper method to create q, k, v tensors."""
|
||||||
|
return (
|
||||||
|
torch.randn(
|
||||||
|
tokens_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
torch.randn(
|
||||||
|
tokens_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
torch.randn(
|
||||||
|
tokens_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _verify_output(self, output, expected_shape):
|
||||||
|
"""Helper method to verify output."""
|
||||||
|
self.assertEqual(
|
||||||
|
output.shape,
|
||||||
|
expected_shape,
|
||||||
|
f"Expected shape {expected_shape}, got {output.shape}",
|
||||||
|
)
|
||||||
|
self.assertEqual(output.dtype, self.dtype)
|
||||||
|
self.assertEqual(output.device.type, "cuda")
|
||||||
|
self.assertEqual(
|
||||||
|
torch.isnan(output).sum().item(), 0, "Output contains NaN values"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_extend(self):
|
||||||
|
"""Test the standard extend operation."""
|
||||||
|
# Create test inputs
|
||||||
|
q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
|
||||||
|
|
||||||
|
# Create attention layer
|
||||||
|
layer = self._create_attention_layer()
|
||||||
|
|
||||||
|
# Create forward batch
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
input_ids=torch.randint(
|
||||||
|
0, 100, (self.batch_size, self.seq_len), device=self.device
|
||||||
|
),
|
||||||
|
out_cache_loc=torch.arange(
|
||||||
|
self.batch_size * self.seq_len, device=self.device
|
||||||
|
),
|
||||||
|
seq_lens_sum=self.batch_size * self.seq_len,
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
|
seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
|
||||||
|
# 0 prefix, 4 extend
|
||||||
|
extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
|
||||||
|
extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
|
||||||
|
attn_backend=self.backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add token pool and KV cache
|
||||||
|
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
||||||
|
self.batch_size, self.seq_len, self.device
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
||||||
|
self.batch_size * self.seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize forward metadata before running the attention
|
||||||
|
self.backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
# Run forward_extend
|
||||||
|
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
expected_shape = (
|
||||||
|
self.batch_size * self.seq_len,
|
||||||
|
self.num_heads * self.head_dim,
|
||||||
|
)
|
||||||
|
self._verify_output(output, expected_shape)
|
||||||
|
|
||||||
|
def test_forward_decode(self):
|
||||||
|
"""Test the decode operation with cached tokens."""
|
||||||
|
# For decode, we only have one token per sequence
|
||||||
|
decode_len = 1
|
||||||
|
curr_seq_len = self.seq_len + decode_len
|
||||||
|
|
||||||
|
# Create test inputs
|
||||||
|
q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
|
||||||
|
|
||||||
|
# Create attention layer
|
||||||
|
layer = self._create_attention_layer()
|
||||||
|
|
||||||
|
# Create forward batch
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
input_ids=torch.randint(
|
||||||
|
0, 100, (self.batch_size, decode_len), device=self.device
|
||||||
|
),
|
||||||
|
out_cache_loc=torch.arange(
|
||||||
|
self.batch_size * self.seq_len,
|
||||||
|
self.batch_size * curr_seq_len,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
seq_lens_sum=self.batch_size * curr_seq_len,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
|
seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
|
||||||
|
attn_backend=self.backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add token pool and KV cache
|
||||||
|
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
||||||
|
self.batch_size, curr_seq_len, self.device
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
||||||
|
self.batch_size * curr_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-fill KV cache
|
||||||
|
cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
torch.arange(self.batch_size * self.seq_len, device=self.device),
|
||||||
|
cache_k,
|
||||||
|
cache_v,
|
||||||
|
layer.k_scale,
|
||||||
|
layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize forward metadata before running the attention
|
||||||
|
self.backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
# Run forward_decode
|
||||||
|
output = self.backend.forward_decode(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
expected_shape = (self.batch_size, self.num_heads * self.head_dim)
|
||||||
|
self._verify_output(output, expected_shape)
|
||||||
|
|
||||||
|
def test_forward_extend_with_prefix(self):
|
||||||
|
"""Test extending from cached prefix tokens."""
|
||||||
|
# Define prefix and extend lengths
|
||||||
|
prefix_len = 2
|
||||||
|
extend_len = 2
|
||||||
|
total_len = prefix_len + extend_len
|
||||||
|
|
||||||
|
# Create test inputs for the extend portion
|
||||||
|
q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
|
||||||
|
|
||||||
|
# Create attention layer
|
||||||
|
layer = self._create_attention_layer()
|
||||||
|
|
||||||
|
# Create forward batch
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
input_ids=torch.randint(
|
||||||
|
0, 100, (self.batch_size, extend_len), device=self.device
|
||||||
|
),
|
||||||
|
out_cache_loc=torch.arange(
|
||||||
|
self.batch_size * prefix_len,
|
||||||
|
self.batch_size * total_len,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
seq_lens_sum=self.batch_size * total_len,
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
|
seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
|
||||||
|
extend_prefix_lens=torch.tensor(
|
||||||
|
[prefix_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
extend_seq_lens=torch.tensor(
|
||||||
|
[extend_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
attn_backend=self.backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add token pool and KV cache
|
||||||
|
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
||||||
|
self.batch_size, total_len, self.device
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
||||||
|
self.batch_size * total_len
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-fill the KV cache for prefix with known values
|
||||||
|
cache_k = torch.ones(
|
||||||
|
self.batch_size * prefix_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
cache_v = (
|
||||||
|
torch.ones(
|
||||||
|
self.batch_size * prefix_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
* 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the prefix KV cache
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
torch.arange(self.batch_size * prefix_len, device=self.device),
|
||||||
|
cache_k,
|
||||||
|
cache_v,
|
||||||
|
layer.k_scale,
|
||||||
|
layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize forward metadata before running the attention
|
||||||
|
self.backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
# Run forward_extend
|
||||||
|
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
|
||||||
|
self._verify_output(output, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user