Add unit test on page_size > 1 and mla and integration test for Flash Attention 3 (#4760)
This commit is contained in:
@@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# Use Flash Attention for prefill
|
# Use Flash Attention for prefill
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
layer.layer_id
|
||||||
|
)
|
||||||
key_cache = key_cache.view(
|
key_cache = key_cache.view(
|
||||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||||
)
|
)
|
||||||
@@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
c_kv_cache = c_kv.view(
|
c_kv_cache = c_kv.view(
|
||||||
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
@@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||||
|
layer.layer_id
|
||||||
|
)
|
||||||
key_cache = key_cache.view(
|
key_cache = key_cache.view(
|
||||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
|||||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
||||||
ParallelLMHead,
|
|
||||||
UnquantizedEmbeddingMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Base quantization methods that don't depend on vllm
|
# Base quantization methods that don't depend on vllm
|
||||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
@@ -176,6 +172,13 @@ def get_linear_quant_method(
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
linear_method_cls: type,
|
linear_method_cls: type,
|
||||||
):
|
):
|
||||||
|
# Move import here to avoid circular import. This is only used in monkey patching
|
||||||
|
# of vllm's QuantizationConfig.
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
UnquantizedEmbeddingMethod,
|
||||||
|
)
|
||||||
|
|
||||||
cloned_config = deepcopy(config)
|
cloned_config = deepcopy(config)
|
||||||
parallel_lm_head_quantized = (
|
parallel_lm_head_quantized = (
|
||||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||||
|
|||||||
@@ -2,60 +2,109 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
|
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
|
||||||
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.model_executor.model_runner import ServerArgs
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
class MockModelRunner:
|
class MockModelRunner:
|
||||||
model_config = type(
|
def __init__(
|
||||||
"ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
|
self,
|
||||||
|
page_size=1,
|
||||||
|
num_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
):
|
||||||
|
self.device = "cuda"
|
||||||
|
self.dtype = torch.float16
|
||||||
|
attention_arch = AttentionArch.MHA
|
||||||
|
# Max batch size for the test.
|
||||||
|
max_batch_size = 160
|
||||||
|
# Total tokens(prefix + extend + decode) in the test should not exceed this length.
|
||||||
|
max_context_len = 2048
|
||||||
|
self.model_config = type(
|
||||||
|
"ModelConfig",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"context_len": max_context_len,
|
||||||
|
"is_multimodal": False,
|
||||||
|
"attention_arch": attention_arch,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
sliding_window_size = None
|
self.sliding_window_size = None
|
||||||
|
self.device = self.device
|
||||||
def __init__(self, device="cuda"):
|
# Create a large enough req_to_token_pool to fit the test usage.
|
||||||
self.device = device
|
|
||||||
# Create a proper req_to_token_pool with the req_to_token attribute
|
|
||||||
self.req_to_token_pool = type(
|
self.req_to_token_pool = type(
|
||||||
"TokenPool",
|
"TokenPool",
|
||||||
(),
|
(),
|
||||||
{
|
{
|
||||||
"size": 160, # a typical max_bs * max_context_len for cuda graph decode
|
# A typical max_bs * max_context_len for cuda graph decode
|
||||||
|
"size": max_batch_size,
|
||||||
|
# Add req_to_token attribute
|
||||||
"req_to_token": torch.zeros(
|
"req_to_token": torch.zeros(
|
||||||
160, 2048, dtype=torch.int32, device=device
|
max_batch_size,
|
||||||
), # Add req_to_token attribute
|
max_context_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.page_size = page_size
|
||||||
|
max_total_num_tokens = max_batch_size * max_context_len
|
||||||
class MockReqToTokenPool:
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
def __init__(self, batch_size, seq_len, device):
|
size=max_total_num_tokens,
|
||||||
self.req_to_token = (
|
page_size=page_size,
|
||||||
torch.arange(batch_size * seq_len, device=device)
|
dtype=self.dtype,
|
||||||
.reshape(batch_size, seq_len)
|
head_num=num_heads,
|
||||||
.to(torch.int32)
|
head_dim=head_dim,
|
||||||
|
layer_num=1, # only consider layer=1 for unit test
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=False,
|
||||||
)
|
)
|
||||||
|
# Required by torch native backend
|
||||||
|
self.server_args = ServerArgs(model_path="fake_model_path")
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
||||||
class TestFlashAttentionBackend(CustomTestCase):
|
class TestFlashAttentionBackend(CustomTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures before each test method."""
|
# Test parameters
|
||||||
self.model_runner = MockModelRunner()
|
|
||||||
self.backend = FlashAttentionBackend(self.model_runner)
|
|
||||||
|
|
||||||
# Common test parameters
|
|
||||||
self.batch_size = 2
|
self.batch_size = 2
|
||||||
self.seq_len = 4
|
self.seq_len = 256
|
||||||
self.num_heads = 2
|
self.num_heads = 2
|
||||||
self.head_dim = 8
|
self.head_dim = 8
|
||||||
self.device = "cuda"
|
self.device = "cuda"
|
||||||
self.dtype = torch.float16
|
self.dtype = torch.float16
|
||||||
|
|
||||||
|
def _init_model_runner(self, page_size=1):
|
||||||
|
self.model_runner = MockModelRunner(
|
||||||
|
page_size=page_size,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
)
|
||||||
|
self.backend = FlashAttentionBackend(self.model_runner)
|
||||||
|
self.ref_backend = TorchNativeAttnBackend(self.model_runner)
|
||||||
|
self.model_runner.model_config.num_attention_heads = self.num_heads
|
||||||
|
|
||||||
|
def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
|
||||||
|
# if page_size > 1, the token pool stores the index to the page.
|
||||||
|
# so we need to multiply the index by page_size.
|
||||||
|
self.req_to_token = (
|
||||||
|
torch.arange(0, batch_size, dtype=torch.int32, device=self.device)[:, None]
|
||||||
|
* seq_len
|
||||||
|
+ torch.arange(0, seq_len, dtype=torch.int32, device=self.device)[None, :]
|
||||||
|
+ page_size
|
||||||
|
)
|
||||||
|
self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
|
||||||
|
self.req_to_token
|
||||||
|
)
|
||||||
|
|
||||||
def _create_attention_layer(self):
|
def _create_attention_layer(self):
|
||||||
"""Helper method to create an attention layer."""
|
"""Create attention layer for testing."""
|
||||||
return RadixAttention(
|
return RadixAttention(
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
head_dim=self.head_dim,
|
head_dim=self.head_dim,
|
||||||
@@ -64,47 +113,27 @@ class TestFlashAttentionBackend(CustomTestCase):
|
|||||||
layer_id=0,
|
layer_id=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_kv_pool(self, size):
|
|
||||||
"""Helper method to create a KV pool."""
|
|
||||||
return MHATokenToKVPool(
|
|
||||||
size=size,
|
|
||||||
page_size=1, # only consider page=1 for unit test
|
|
||||||
dtype=self.dtype,
|
|
||||||
head_num=self.num_heads,
|
|
||||||
head_dim=self.head_dim,
|
|
||||||
layer_num=1, # only consider layer=1 for unit test
|
|
||||||
device=self.device,
|
|
||||||
enable_memory_saver=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_qkv_tensors(self, tokens_len):
|
def _create_qkv_tensors(self, tokens_len):
|
||||||
"""Helper method to create q, k, v tensors."""
|
"""Create q, k, v tensors for testing."""
|
||||||
|
shape = (tokens_len, self.num_heads, self.head_dim)
|
||||||
return (
|
return (
|
||||||
torch.randn(
|
torch.randn(shape, dtype=self.dtype, device=self.device),
|
||||||
tokens_len,
|
torch.randn(shape, dtype=self.dtype, device=self.device),
|
||||||
self.num_heads,
|
torch.randn(shape, dtype=self.dtype, device=self.device),
|
||||||
self.head_dim,
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.randn(
|
|
||||||
tokens_len,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.randn(
|
|
||||||
tokens_len,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _verify_output(self, output, expected_shape):
|
def _run_reference_forward(
|
||||||
"""Helper method to verify output."""
|
self, mode, q, k, v, layer, forward_batch, expected_shape
|
||||||
|
):
|
||||||
|
"""Run reference forward pass using native backend."""
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
else: # ForwardMode.DECODE
|
||||||
|
output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
|
||||||
|
return output.view(expected_shape)
|
||||||
|
|
||||||
|
def _verify_output(self, output, expected_shape, output_ref=None):
|
||||||
|
"""Verify output tensor shape, dtype, and values."""
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
output.shape,
|
output.shape,
|
||||||
expected_shape,
|
expected_shape,
|
||||||
@@ -116,161 +145,110 @@ class TestFlashAttentionBackend(CustomTestCase):
|
|||||||
torch.isnan(output).sum().item(), 0, "Output contains NaN values"
|
torch.isnan(output).sum().item(), 0, "Output contains NaN values"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_forward_extend(self):
|
if output_ref is not None:
|
||||||
"""Test the standard extend operation."""
|
if not torch.allclose(output, output_ref, atol=1e-1, rtol=0.0):
|
||||||
# Create test inputs
|
# Check where the values differ beyond the given tolerances
|
||||||
q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
|
diff_mask = ~torch.isclose(output, output_ref, atol=1e-1, rtol=0.0)
|
||||||
|
|
||||||
# Create attention layer
|
# Find the first index where the difference occurs
|
||||||
layer = self._create_attention_layer()
|
if diff_mask.any():
|
||||||
|
first_mismatch_idx = diff_mask.nonzero()[0]
|
||||||
|
print(
|
||||||
|
"First mismatch at index:", tuple(first_mismatch_idx.tolist())
|
||||||
|
)
|
||||||
|
print("output:", output[tuple(first_mismatch_idx.tolist())])
|
||||||
|
print("output_ref:", output_ref[tuple(first_mismatch_idx.tolist())])
|
||||||
|
raise AssertionError(
|
||||||
|
"Attention output is not close to the torch native backend output"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
|
||||||
|
"""Create a forward batch for testing based on mode and lengths."""
|
||||||
|
self._init_model_runner(page_size=page_size)
|
||||||
|
|
||||||
|
# Default to self.seq_len if not specified
|
||||||
|
q_len = q_len or self.seq_len
|
||||||
|
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
total_len = prefix_len + q_len
|
||||||
|
out_cache_start = prefix_len * self.batch_size
|
||||||
|
out_cache_end = total_len * self.batch_size
|
||||||
|
|
||||||
# Create forward batch
|
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
input_ids=torch.randint(
|
input_ids=torch.randint(
|
||||||
0, 100, (self.batch_size, self.seq_len), device=self.device
|
0, 100, (self.batch_size, q_len), device=self.device
|
||||||
),
|
),
|
||||||
out_cache_loc=torch.arange(
|
out_cache_loc=torch.arange(
|
||||||
self.batch_size * self.seq_len, device=self.device
|
out_cache_start, out_cache_end, device=self.device
|
||||||
),
|
),
|
||||||
seq_lens_sum=self.batch_size * self.seq_len,
|
seq_lens_sum=self.batch_size * total_len,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=mode,
|
||||||
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
|
seq_lens=torch.tensor(
|
||||||
# 0 prefix, 4 extend
|
[total_len] * self.batch_size, device=self.device
|
||||||
extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
|
),
|
||||||
extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
|
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
|
||||||
|
extend_prefix_lens=torch.tensor(
|
||||||
|
[prefix_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
extend_prefix_lens_cpu=torch.tensor(
|
||||||
|
[prefix_len] * self.batch_size, device="cpu"
|
||||||
|
),
|
||||||
|
extend_seq_lens=torch.tensor(
|
||||||
|
[q_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
extend_seq_lens_cpu=torch.tensor(
|
||||||
|
[q_len] * self.batch_size, device="cpu"
|
||||||
|
),
|
||||||
attn_backend=self.backend,
|
attn_backend=self.backend,
|
||||||
)
|
)
|
||||||
|
else: # ForwardMode.DECODE
|
||||||
|
decode_len = q_len # Assuming 1 for decode testing
|
||||||
|
total_len = self.seq_len + decode_len
|
||||||
|
if mode == ForwardMode.DECODE and page_size > 1:
|
||||||
|
# Get next page_size multiple of self.seq_len
|
||||||
|
out_cache_start = (
|
||||||
|
self.batch_size * self.seq_len // page_size + 1
|
||||||
|
) * page_size
|
||||||
|
# out_cache_end is the start of the next block
|
||||||
|
out_cache_end = out_cache_start + decode_len * page_size
|
||||||
|
else:
|
||||||
|
out_cache_start = self.batch_size * self.seq_len
|
||||||
|
out_cache_end = self.batch_size * total_len
|
||||||
|
|
||||||
# Add token pool and KV cache
|
|
||||||
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
|
||||||
self.batch_size, self.seq_len, self.device
|
|
||||||
)
|
|
||||||
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
|
||||||
self.batch_size * self.seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize forward metadata before running the attention
|
|
||||||
self.backend.init_forward_metadata(forward_batch)
|
|
||||||
|
|
||||||
# Run forward_extend
|
|
||||||
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
|
||||||
|
|
||||||
# Verify output
|
|
||||||
expected_shape = (
|
|
||||||
self.batch_size * self.seq_len,
|
|
||||||
self.num_heads * self.head_dim,
|
|
||||||
)
|
|
||||||
self._verify_output(output, expected_shape)
|
|
||||||
|
|
||||||
def test_forward_decode(self):
|
|
||||||
"""Test the decode operation with cached tokens."""
|
|
||||||
# For decode, we only have one token per sequence
|
|
||||||
decode_len = 1
|
|
||||||
curr_seq_len = self.seq_len + decode_len
|
|
||||||
|
|
||||||
# Create test inputs
|
|
||||||
q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
|
|
||||||
|
|
||||||
# Create attention layer
|
|
||||||
layer = self._create_attention_layer()
|
|
||||||
|
|
||||||
# Create forward batch
|
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
input_ids=torch.randint(
|
input_ids=torch.randint(
|
||||||
0, 100, (self.batch_size, decode_len), device=self.device
|
0, 100, (self.batch_size, decode_len), device=self.device
|
||||||
),
|
),
|
||||||
out_cache_loc=torch.arange(
|
out_cache_loc=torch.tensor(
|
||||||
self.batch_size * self.seq_len,
|
[out_cache_start, out_cache_end], device=self.device
|
||||||
self.batch_size * curr_seq_len,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
seq_lens_sum=self.batch_size * curr_seq_len,
|
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
|
||||||
seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
|
|
||||||
attn_backend=self.backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add token pool and KV cache
|
|
||||||
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
|
||||||
self.batch_size, curr_seq_len, self.device
|
|
||||||
)
|
|
||||||
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
|
||||||
self.batch_size * curr_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pre-fill KV cache
|
|
||||||
cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
|
||||||
layer,
|
|
||||||
torch.arange(self.batch_size * self.seq_len, device=self.device),
|
|
||||||
cache_k,
|
|
||||||
cache_v,
|
|
||||||
layer.k_scale,
|
|
||||||
layer.v_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize forward metadata before running the attention
|
|
||||||
self.backend.init_forward_metadata(forward_batch)
|
|
||||||
|
|
||||||
# Run forward_decode
|
|
||||||
output = self.backend.forward_decode(q, k, v, layer, forward_batch)
|
|
||||||
|
|
||||||
# Verify output
|
|
||||||
expected_shape = (self.batch_size, self.num_heads * self.head_dim)
|
|
||||||
self._verify_output(output, expected_shape)
|
|
||||||
|
|
||||||
def test_forward_extend_with_prefix(self):
|
|
||||||
"""Test extending from cached prefix tokens."""
|
|
||||||
# Define prefix and extend lengths
|
|
||||||
prefix_len = 2
|
|
||||||
extend_len = 2
|
|
||||||
total_len = prefix_len + extend_len
|
|
||||||
|
|
||||||
# Create test inputs for the extend portion
|
|
||||||
q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
|
|
||||||
|
|
||||||
# Create attention layer
|
|
||||||
layer = self._create_attention_layer()
|
|
||||||
|
|
||||||
# Create forward batch
|
|
||||||
forward_batch = ForwardBatch(
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
input_ids=torch.randint(
|
|
||||||
0, 100, (self.batch_size, extend_len), device=self.device
|
|
||||||
),
|
|
||||||
out_cache_loc=torch.arange(
|
|
||||||
self.batch_size * prefix_len,
|
|
||||||
self.batch_size * total_len,
|
|
||||||
device=self.device,
|
|
||||||
),
|
),
|
||||||
seq_lens_sum=self.batch_size * total_len,
|
seq_lens_sum=self.batch_size * total_len,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=mode,
|
||||||
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
|
seq_lens=torch.tensor(
|
||||||
extend_prefix_lens=torch.tensor(
|
[total_len] * self.batch_size, device=self.device
|
||||||
[prefix_len] * self.batch_size, device=self.device
|
|
||||||
),
|
|
||||||
extend_seq_lens=torch.tensor(
|
|
||||||
[extend_len] * self.batch_size, device=self.device
|
|
||||||
),
|
),
|
||||||
|
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
|
||||||
attn_backend=self.backend,
|
attn_backend=self.backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add token pool and KV cache
|
# Add token pool
|
||||||
forward_batch.req_to_token_pool = MockReqToTokenPool(
|
forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
|
||||||
self.batch_size, total_len, self.device
|
|
||||||
)
|
|
||||||
forward_batch.token_to_kv_pool = self._create_kv_pool(
|
|
||||||
self.batch_size * total_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pre-fill the KV cache for prefix with known values
|
# Write current batch's req_to_token to req_to_token_pool
|
||||||
|
self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size)
|
||||||
|
# Add kv pool for this forward batch
|
||||||
|
forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
|
||||||
|
|
||||||
|
return forward_batch
|
||||||
|
|
||||||
|
def _setup_kv_cache(self, forward_batch, layer, cache_len):
|
||||||
|
# Create constant values for the prefix cache for easy debugging
|
||||||
cache_k = torch.ones(
|
cache_k = torch.ones(
|
||||||
self.batch_size * prefix_len,
|
self.batch_size * cache_len,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
@@ -278,7 +256,7 @@ class TestFlashAttentionBackend(CustomTestCase):
|
|||||||
)
|
)
|
||||||
cache_v = (
|
cache_v = (
|
||||||
torch.ones(
|
torch.ones(
|
||||||
self.batch_size * prefix_len,
|
self.batch_size * cache_len,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
@@ -290,22 +268,82 @@ class TestFlashAttentionBackend(CustomTestCase):
|
|||||||
# Set the prefix KV cache
|
# Set the prefix KV cache
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer,
|
layer,
|
||||||
torch.arange(self.batch_size * prefix_len, device=self.device),
|
torch.arange(self.batch_size * cache_len, device=self.device),
|
||||||
cache_k,
|
cache_k,
|
||||||
cache_v,
|
cache_v,
|
||||||
layer.k_scale,
|
layer.k_scale,
|
||||||
layer.v_scale,
|
layer.v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize forward metadata before running the attention
|
def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
|
||||||
|
"""
|
||||||
|
Run an attention test with the specified parameters.
|
||||||
|
Args:
|
||||||
|
mode: ForwardMode.EXTEND or ForwardMode.DECODE
|
||||||
|
q_len: Length of the query sequence. For decode mode, q_len is 1.
|
||||||
|
prefix_len: Length of the prefix sequence for extend mode
|
||||||
|
page_size: Page size for the KV cache
|
||||||
|
"""
|
||||||
|
layer = self._create_attention_layer()
|
||||||
|
|
||||||
|
# Create forward batch and set up
|
||||||
|
forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size)
|
||||||
|
|
||||||
|
# Create QKV tensors for the input
|
||||||
|
q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
|
||||||
|
|
||||||
|
# KV cache for prefixed extend is prefix_len
|
||||||
|
# KV cache for decode is same as seq_len
|
||||||
|
# No KV cache for extend without prefix
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
if prefix_len > 0:
|
||||||
|
self._setup_kv_cache(forward_batch, layer, prefix_len)
|
||||||
|
else:
|
||||||
|
self._setup_kv_cache(forward_batch, layer, self.seq_len)
|
||||||
|
|
||||||
self.backend.init_forward_metadata(forward_batch)
|
self.backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
# Run forward_extend
|
if mode == ForwardMode.EXTEND:
|
||||||
|
expected_shape = (
|
||||||
|
self.batch_size * q_len,
|
||||||
|
self.num_heads * self.head_dim,
|
||||||
|
)
|
||||||
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
else:
|
||||||
|
expected_shape = (self.batch_size, self.num_heads * self.head_dim)
|
||||||
|
output = self.backend.forward_decode(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
# Verify output
|
output_ref = self._run_reference_forward(
|
||||||
expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
|
mode, q, k, v, layer, forward_batch, expected_shape
|
||||||
self._verify_output(output, expected_shape)
|
)
|
||||||
|
|
||||||
|
self._verify_output(output, expected_shape, output_ref)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_forward_extend(self):
|
||||||
|
"""Test the standard extend operation."""
|
||||||
|
self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
|
||||||
|
|
||||||
|
def test_forward_decode(self):
|
||||||
|
"""Test the decode operation with cached tokens."""
|
||||||
|
self._run_attention_test(ForwardMode.DECODE, q_len=1)
|
||||||
|
|
||||||
|
def test_forward_extend_with_prefix(self):
|
||||||
|
"""Test extending from cached prefix tokens."""
|
||||||
|
prefix_len = self.seq_len // 2
|
||||||
|
extend_len = self.seq_len - prefix_len
|
||||||
|
self._run_attention_test(
|
||||||
|
ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_extend_with_page_size_greater_than_1(self):
|
||||||
|
"""Test extending from cached prefix tokens with page size greater than 1."""
|
||||||
|
self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)
|
||||||
|
|
||||||
|
def test_forward_decode_with_page_size_greater_than_1(self):
|
||||||
|
"""Test decode operation with page size greater than 1."""
|
||||||
|
self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
285
python/sglang/test/attention/test_flashattn_mla_backend.py
Normal file
285
python/sglang/test/attention/test_flashattn_mla_backend.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
|
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
|
||||||
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class MockModelRunner:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kv_lora_rank,
|
||||||
|
qk_rope_head_dim,
|
||||||
|
):
|
||||||
|
attention_arch = AttentionArch.MLA
|
||||||
|
self.device = "cuda"
|
||||||
|
self.dtype = torch.float16
|
||||||
|
context_len = 2048
|
||||||
|
self.model_config = type(
|
||||||
|
"ModelConfig",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"context_len": context_len,
|
||||||
|
"attention_arch": attention_arch,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.sliding_window_size = None
|
||||||
|
|
||||||
|
batch_size = 160
|
||||||
|
# Create a proper req_to_token_pool with the req_to_token attribute
|
||||||
|
self.req_to_token_pool = type(
|
||||||
|
"TokenPool",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
# A typical max_bs * max_context_len for cuda graph decode
|
||||||
|
"size": batch_size,
|
||||||
|
# Add req_to_token attribute
|
||||||
|
"req_to_token": torch.zeros(
|
||||||
|
batch_size, context_len, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.page_size = 1
|
||||||
|
max_total_num_tokens = batch_size * context_len
|
||||||
|
self.token_to_kv_pool = MLATokenToKVPool(
|
||||||
|
size=max_total_num_tokens,
|
||||||
|
page_size=self.page_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kv_lora_rank=kv_lora_rank,
|
||||||
|
qk_rope_head_dim=qk_rope_head_dim,
|
||||||
|
layer_num=1, # only consider layer=1 for unit test
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockReqToTokenPool:
|
||||||
|
def __init__(self, batch_size, seq_len, device):
|
||||||
|
self.req_to_token = (
|
||||||
|
torch.arange(batch_size * seq_len, device=device)
|
||||||
|
.reshape(batch_size, seq_len)
|
||||||
|
.to(torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
||||||
|
class TestFlashAttentionMLABackend(CustomTestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Test parameters
|
||||||
|
self.batch_size = 2
|
||||||
|
self.seq_len = 360
|
||||||
|
self.num_heads = 2
|
||||||
|
self.device = "cuda"
|
||||||
|
self.dtype = torch.float16
|
||||||
|
self.kv_lora_rank = 512
|
||||||
|
self.q_lora_rank = 128
|
||||||
|
self.qk_rope_head_dim = 64
|
||||||
|
self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
|
||||||
|
# Assume no rope scaling
|
||||||
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
|
# Initialize model runner and backend
|
||||||
|
self._init_model_runner()
|
||||||
|
self.backend = FlashAttentionBackend(self.model_runner)
|
||||||
|
self.num_local_heads = 2
|
||||||
|
|
||||||
|
def _init_model_runner(self):
|
||||||
|
self.model_runner = MockModelRunner(
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
self.backend = FlashAttentionBackend(self.model_runner)
|
||||||
|
|
||||||
|
def _create_attention_layer(self):
|
||||||
|
"""Create attention layer for testing."""
|
||||||
|
self.attn_mqa = RadixAttention(
|
||||||
|
num_heads=self.num_local_heads,
|
||||||
|
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
scaling=self.scaling,
|
||||||
|
num_kv_heads=1,
|
||||||
|
layer_id=0,
|
||||||
|
v_head_dim=self.kv_lora_rank,
|
||||||
|
prefix="attn_mqa",
|
||||||
|
)
|
||||||
|
return self.attn_mqa
|
||||||
|
|
||||||
|
def _run_reference_forward(
|
||||||
|
self, mode, q, k, v, layer, forward_batch, expected_shape
|
||||||
|
):
|
||||||
|
"""Run reference forward pass using native backend."""
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
else: # ForwardMode.DECODE
|
||||||
|
output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
|
||||||
|
return output.view(expected_shape)
|
||||||
|
|
||||||
|
def _verify_output(self, output, expected_shape):
|
||||||
|
"""Verify output tensor shape, dtype, and values."""
|
||||||
|
self.assertEqual(
|
||||||
|
output.shape,
|
||||||
|
expected_shape,
|
||||||
|
f"Expected shape {expected_shape}, got {output.shape}",
|
||||||
|
)
|
||||||
|
self.assertEqual(output.dtype, self.dtype)
|
||||||
|
self.assertEqual(output.device.type, "cuda")
|
||||||
|
self.assertEqual(
|
||||||
|
torch.isnan(output).sum().item(), 0, "Output contains NaN values"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
|
||||||
|
"""Create a forward batch for testing based on mode and lengths."""
|
||||||
|
# Default to self.seq_len if not specified
|
||||||
|
q_len = q_len or self.seq_len
|
||||||
|
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
total_len = prefix_len + q_len
|
||||||
|
out_cache_start = prefix_len * self.batch_size
|
||||||
|
out_cache_end = total_len * self.batch_size
|
||||||
|
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
input_ids=torch.randint(
|
||||||
|
0, 100, (self.batch_size, q_len), device=self.device
|
||||||
|
),
|
||||||
|
out_cache_loc=torch.arange(
|
||||||
|
out_cache_start, out_cache_end, device=self.device
|
||||||
|
),
|
||||||
|
seq_lens_sum=self.batch_size * total_len,
|
||||||
|
forward_mode=mode,
|
||||||
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
|
seq_lens=torch.tensor(
|
||||||
|
[total_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
|
||||||
|
extend_prefix_lens=torch.tensor(
|
||||||
|
[prefix_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
extend_prefix_lens_cpu=torch.tensor(
|
||||||
|
[prefix_len] * self.batch_size, device="cpu"
|
||||||
|
),
|
||||||
|
extend_seq_lens=torch.tensor(
|
||||||
|
[q_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
extend_seq_lens_cpu=torch.tensor(
|
||||||
|
[q_len] * self.batch_size, device="cpu"
|
||||||
|
),
|
||||||
|
attn_backend=self.backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
else: # ForwardMode.DECODE
|
||||||
|
decode_len = q_len # typically 1 for decode mode
|
||||||
|
total_len = self.seq_len + decode_len
|
||||||
|
out_cache_start = self.batch_size * self.seq_len
|
||||||
|
out_cache_end = self.batch_size * total_len
|
||||||
|
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
input_ids=torch.randint(
|
||||||
|
0, 100, (self.batch_size, decode_len), device=self.device
|
||||||
|
),
|
||||||
|
out_cache_loc=torch.arange(
|
||||||
|
out_cache_start, out_cache_end, device=self.device
|
||||||
|
),
|
||||||
|
seq_lens_sum=self.batch_size * total_len,
|
||||||
|
forward_mode=mode,
|
||||||
|
req_pool_indices=torch.arange(self.batch_size, device=self.device),
|
||||||
|
seq_lens=torch.tensor(
|
||||||
|
[total_len] * self.batch_size, device=self.device
|
||||||
|
),
|
||||||
|
seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
|
||||||
|
attn_backend=self.backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add token pool from model runner to forward batch
|
||||||
|
forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
|
||||||
|
|
||||||
|
# Add KV cache from model runner to forward batch
|
||||||
|
forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
|
||||||
|
|
||||||
|
return forward_batch
|
||||||
|
|
||||||
|
def _setup_kv_cache(self, forward_batch, layer, cache_len):
|
||||||
|
"""Set up KV cache with prefix tokens."""
|
||||||
|
if cache_len <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create constant values for the prefix cache for easy debugging
|
||||||
|
latent_cache = torch.ones(
|
||||||
|
self.batch_size * cache_len,
|
||||||
|
1, # latent cache has only one head in MQA
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the prefix KV cache
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer,
|
||||||
|
torch.arange(self.batch_size * cache_len, device=self.device),
|
||||||
|
latent_cache,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_attention_test(self, mode, q_len, prefix_len=0):
|
||||||
|
"""
|
||||||
|
Run an attention test with the specified parameters.
|
||||||
|
Args:
|
||||||
|
mode: ForwardMode.EXTEND or ForwardMode.DECODE
|
||||||
|
q_len: Length of the query sequence. For decode mode, q_len is 1.
|
||||||
|
prefix_len: Length of the prefix sequence for extend mode
|
||||||
|
"""
|
||||||
|
layer = self._create_attention_layer()
|
||||||
|
|
||||||
|
# Create forward batch and set up
|
||||||
|
forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
|
||||||
|
|
||||||
|
# Create q, kv_compressed for testing
|
||||||
|
q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
|
||||||
|
kv_shape = (self.batch_size * q_len, self.qk_head_dim)
|
||||||
|
q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
|
||||||
|
kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
|
||||||
|
# v is not used for mqa, all values passed in through k
|
||||||
|
k = kv_compressed.unsqueeze(1)
|
||||||
|
v = torch.randn((1), dtype=self.dtype, device=self.device)
|
||||||
|
|
||||||
|
self._setup_kv_cache(forward_batch, layer, prefix_len)
|
||||||
|
|
||||||
|
self.backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
expected_shape = (
|
||||||
|
self.batch_size * q_len,
|
||||||
|
self.num_heads * self.kv_lora_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mode == ForwardMode.EXTEND:
|
||||||
|
output = self.backend.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
else:
|
||||||
|
output = self.backend.forward_decode(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
|
self._verify_output(output, expected_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_forward_extend(self):
|
||||||
|
"""Test the standard extend operation."""
|
||||||
|
self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
|
||||||
|
|
||||||
|
def test_forward_decode(self):
|
||||||
|
"""Test the decode operation with cached tokens."""
|
||||||
|
self._run_attention_test(ForwardMode.DECODE, q_len=1)
|
||||||
|
|
||||||
|
def test_forward_extend_with_prefix(self):
|
||||||
|
"""Test extending from cached prefix tokens."""
|
||||||
|
prefix_len = self.seq_len // 2
|
||||||
|
extend_len = self.seq_len - prefix_len
|
||||||
|
self._run_attention_test(
|
||||||
|
ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -28,6 +28,7 @@ suites = {
|
|||||||
TestFile("test_chunked_prefill.py", 336),
|
TestFile("test_chunked_prefill.py", 336),
|
||||||
TestFile("test_eagle_infer.py", 500),
|
TestFile("test_eagle_infer.py", 500),
|
||||||
TestFile("test_ebnf_constrained.py"),
|
TestFile("test_ebnf_constrained.py"),
|
||||||
|
TestFile("test_fa3.py", 5),
|
||||||
TestFile("test_fp8_kernel.py", 8),
|
TestFile("test_fp8_kernel.py", 8),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
|
|||||||
180
test/srt/test_fa3.py
Normal file
180
test/srt/test_fa3.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_device_sm, kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
|
||||||
|
"""
|
||||||
|
# Change to your own model if testing model is not public.
|
||||||
|
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
|
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
|
||||||
|
DATA_PATH = None
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
|
||||||
|
class BaseFlashAttentionTest(unittest.TestCase):
|
||||||
|
"""Base class for FlashAttention tests to reduce code duplication."""
|
||||||
|
|
||||||
|
model = MODEL_USED_FOR_TEST
|
||||||
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
accuracy_threshold = 0.62
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
"""Return the arguments for the server launch. Override in subclasses."""
|
||||||
|
args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--enable-torch-compile",
|
||||||
|
"--attention-backend",
|
||||||
|
"fa3",
|
||||||
|
]
|
||||||
|
return args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=cls.get_server_args(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
data_path=DATA_PATH,
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(metrics)
|
||||||
|
|
||||||
|
# Use the appropriate metric key based on the test class
|
||||||
|
metric_key = "accuracy"
|
||||||
|
self.assertGreater(metrics[metric_key], self.accuracy_threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with MLA model and CUDA graph enabled."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with CUDA graph disabled."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--disable-cuda-graph",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3MLA(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with MLA."""
|
||||||
|
|
||||||
|
model = MODEL_USED_FOR_TEST_MLA
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with speculative decode enabled."""
|
||||||
|
|
||||||
|
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"2",
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE3",
|
||||||
|
"--speculative-draft",
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"3",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"1",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"3",
|
||||||
|
"--dtype",
|
||||||
|
"float16",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
"""
|
||||||
|
Override the test_gsm8k to further test for average speculative accept length.
|
||||||
|
"""
|
||||||
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=DATA_PATH,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(metrics)
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
self.assertGreater(avg_spec_accept_length, 1.5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user