diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index dca38d9bb..e6cb04875 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -548,8 +548,9 @@ class FlashAttentionBackend(AttentionBackend): # Use Flash Attention for prefill if not self.use_mla: # Do multi-head attention - kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - key_cache, value_cache = kv_cache[0], kv_cache[1] + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) key_cache = key_cache.view( -1, self.page_size, layer.tp_k_head_num, layer.head_dim ) @@ -592,7 +593,6 @@ class FlashAttentionBackend(AttentionBackend): c_kv_cache = c_kv.view( -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim ) - q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) q_nope = q_all[:, :, : layer.v_head_dim] q_rope = q_all[:, :, layer.v_head_dim :] @@ -659,8 +659,10 @@ class FlashAttentionBackend(AttentionBackend): if not self.use_mla: # Do multi-head attention - kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - key_cache, value_cache = kv_cache[0], kv_cache[1] + + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) key_cache = key_cache.view( -1, self.page_size, layer.tp_k_head_num, layer.head_dim ) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index b2375bc35..8bd410df6 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -63,10 +63,6 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - UnquantizedEmbeddingMethod, -) # Base quantization methods that don't depend on vllm BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -176,6 +172,13 @@ def get_linear_quant_method( prefix: str, linear_method_cls: type, ): + # Move import here to avoid circular import. This is only used in monkey patching + # of vllm's QuantizationConfig. 
+ from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + UnquantizedEmbeddingMethod, + ) + cloned_config = deepcopy(config) parallel_lm_head_quantized = ( isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized diff --git a/python/sglang/test/attention/test_flashattn_backend.py b/python/sglang/test/attention/test_flashattn_backend.py index 41fd3727c..5e5ebbaf1 100644 --- a/python/sglang/test/attention/test_flashattn_backend.py +++ b/python/sglang/test/attention/test_flashattn_backend.py @@ -2,60 +2,109 @@ import unittest import torch +from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend +from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ServerArgs from sglang.test.test_utils import CustomTestCase class MockModelRunner: - model_config = type( - "ModelConfig", (), {"context_len": 2048, "is_multimodal": False} - ) - sliding_window_size = None - - def __init__(self, device="cuda"): - self.device = device - # Create a proper req_to_token_pool with the req_to_token attribute + def __init__( + self, + page_size=1, + num_heads=2, + head_dim=8, + ): + self.device = "cuda" + self.dtype = torch.float16 + attention_arch = AttentionArch.MHA + # Max batch size for the test. + max_batch_size = 160 + # Total tokens(prefix + extend + decode) in the test should not exceed this length. + max_context_len = 2048 + self.model_config = type( + "ModelConfig", + (), + { + "context_len": max_context_len, + "is_multimodal": False, + "attention_arch": attention_arch, + }, + ) + self.sliding_window_size = None + self.device = self.device + # Create a large enough req_to_token_pool to fit the test usage. 
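+        # req_to_token maps (request slot, token position) -> slot index in the KV token pool;
+        # the test later fills it so the backends can resolve out_cache_loc lookups.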
self.req_to_token_pool = type( "TokenPool", (), { - "size": 160, # a typical max_bs * max_context_len for cuda graph decode + # A typical max_bs * max_context_len for cuda graph decode + "size": max_batch_size, + # Add req_to_token attribute "req_to_token": torch.zeros( - 160, 2048, dtype=torch.int32, device=device - ), # Add req_to_token attribute + max_batch_size, + max_context_len, + dtype=torch.int32, + device=self.device, + ), }, ) - - -class MockReqToTokenPool: - def __init__(self, batch_size, seq_len, device): - self.req_to_token = ( - torch.arange(batch_size * seq_len, device=device) - .reshape(batch_size, seq_len) - .to(torch.int32) + self.page_size = page_size + max_total_num_tokens = max_batch_size * max_context_len + self.token_to_kv_pool = MHATokenToKVPool( + size=max_total_num_tokens, + page_size=page_size, + dtype=self.dtype, + head_num=num_heads, + head_dim=head_dim, + layer_num=1, # only consider layer=1 for unit test + device=self.device, + enable_memory_saver=False, ) + # Required by torch native backend + self.server_args = ServerArgs(model_path="fake_model_path") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") class TestFlashAttentionBackend(CustomTestCase): def setUp(self): - """Set up test fixtures before each test method.""" - self.model_runner = MockModelRunner() - self.backend = FlashAttentionBackend(self.model_runner) - - # Common test parameters + # Test parameters self.batch_size = 2 - self.seq_len = 4 + self.seq_len = 256 self.num_heads = 2 self.head_dim = 8 self.device = "cuda" self.dtype = torch.float16 + def _init_model_runner(self, page_size=1): + self.model_runner = MockModelRunner( + page_size=page_size, + num_heads=self.num_heads, + head_dim=self.head_dim, + ) + self.backend = FlashAttentionBackend(self.model_runner) + self.ref_backend = TorchNativeAttnBackend(self.model_runner) + self.model_runner.model_config.num_attention_heads = self.num_heads + + def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size): + # if page_size > 1, the token pool stores the index to the page. + # so we need to multiply the index by page_size. 
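+        # Request i is mapped to the contiguous slot range [i * seq_len, (i + 1) * seq_len), offset by page_size.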
+ self.req_to_token = ( + torch.arange(0, batch_size, dtype=torch.int32, device=self.device)[:, None] + * seq_len + + torch.arange(0, seq_len, dtype=torch.int32, device=self.device)[None, :] + + page_size + ) + self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = ( + self.req_to_token + ) + def _create_attention_layer(self): - """Helper method to create an attention layer.""" + """Create attention layer for testing.""" return RadixAttention( num_heads=self.num_heads, head_dim=self.head_dim, @@ -64,47 +113,27 @@ class TestFlashAttentionBackend(CustomTestCase): layer_id=0, ) - def _create_kv_pool(self, size): - """Helper method to create a KV pool.""" - return MHATokenToKVPool( - size=size, - page_size=1, # only consider page=1 for unit test - dtype=self.dtype, - head_num=self.num_heads, - head_dim=self.head_dim, - layer_num=1, # only consider layer=1 for unit test - device=self.device, - enable_memory_saver=False, - ) - def _create_qkv_tensors(self, tokens_len): - """Helper method to create q, k, v tensors.""" + """Create q, k, v tensors for testing.""" + shape = (tokens_len, self.num_heads, self.head_dim) return ( - torch.randn( - tokens_len, - self.num_heads, - self.head_dim, - dtype=self.dtype, - device=self.device, - ), - torch.randn( - tokens_len, - self.num_heads, - self.head_dim, - dtype=self.dtype, - device=self.device, - ), - torch.randn( - tokens_len, - self.num_heads, - self.head_dim, - dtype=self.dtype, - device=self.device, - ), + torch.randn(shape, dtype=self.dtype, device=self.device), + torch.randn(shape, dtype=self.dtype, device=self.device), + torch.randn(shape, dtype=self.dtype, device=self.device), ) - def _verify_output(self, output, expected_shape): - """Helper method to verify output.""" + def _run_reference_forward( + self, mode, q, k, v, layer, forward_batch, expected_shape + ): + """Run reference forward pass using native backend.""" + if mode == ForwardMode.EXTEND: + output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch) + else: # ForwardMode.DECODE + output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch) + return output.view(expected_shape) + + def _verify_output(self, output, expected_shape, output_ref=None): + """Verify output tensor shape, dtype, and values.""" self.assertEqual( output.shape, expected_shape, @@ -116,161 +145,110 @@ class TestFlashAttentionBackend(CustomTestCase): torch.isnan(output).sum().item(), 0, "Output contains NaN values" ) - def test_forward_extend(self): - """Test the standard extend operation.""" - # Create test inputs - q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len) + if output_ref is not None: + if not torch.allclose(output, output_ref, atol=1e-1, rtol=0.0): + # Check where the values differ beyond the given tolerances + diff_mask = ~torch.isclose(output, output_ref, atol=1e-1, rtol=0.0) - # Create attention layer - layer = self._create_attention_layer() + # Find the first index where the difference occurs + if diff_mask.any(): + first_mismatch_idx = diff_mask.nonzero()[0] + print( + "First mismatch at index:", tuple(first_mismatch_idx.tolist()) + ) + print("output:", output[tuple(first_mismatch_idx.tolist())]) + print("output_ref:", output_ref[tuple(first_mismatch_idx.tolist())]) + raise AssertionError( + "Attention output is not close to the torch native backend output" + ) - # Create forward batch - forward_batch = ForwardBatch( - batch_size=self.batch_size, - input_ids=torch.randint( - 0, 100, (self.batch_size, self.seq_len), device=self.device - ), - 
out_cache_loc=torch.arange( - self.batch_size * self.seq_len, device=self.device - ), - seq_lens_sum=self.batch_size * self.seq_len, - forward_mode=ForwardMode.EXTEND, - req_pool_indices=torch.arange(self.batch_size, device=self.device), - seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device), - # 0 prefix, 4 extend - extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device), - extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device), - attn_backend=self.backend, - ) + def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1): + """Create a forward batch for testing based on mode and lengths.""" + self._init_model_runner(page_size=page_size) - # Add token pool and KV cache - forward_batch.req_to_token_pool = MockReqToTokenPool( - self.batch_size, self.seq_len, self.device - ) - forward_batch.token_to_kv_pool = self._create_kv_pool( - self.batch_size * self.seq_len - ) + # Default to self.seq_len if not specified + q_len = q_len or self.seq_len - # Initialize forward metadata before running the attention - self.backend.init_forward_metadata(forward_batch) + if mode == ForwardMode.EXTEND: + total_len = prefix_len + q_len + out_cache_start = prefix_len * self.batch_size + out_cache_end = total_len * self.batch_size - # Run forward_extend - output = self.backend.forward_extend(q, k, v, layer, forward_batch) + forward_batch = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint( + 0, 100, (self.batch_size, q_len), device=self.device + ), + out_cache_loc=torch.arange( + out_cache_start, out_cache_end, device=self.device + ), + seq_lens_sum=self.batch_size * total_len, + forward_mode=mode, + req_pool_indices=torch.arange(self.batch_size, device=self.device), + seq_lens=torch.tensor( + [total_len] * self.batch_size, device=self.device + ), + seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"), + extend_prefix_lens=torch.tensor( + [prefix_len] * self.batch_size, device=self.device + ), + extend_prefix_lens_cpu=torch.tensor( + [prefix_len] * self.batch_size, device="cpu" + ), + extend_seq_lens=torch.tensor( + [q_len] * self.batch_size, device=self.device + ), + extend_seq_lens_cpu=torch.tensor( + [q_len] * self.batch_size, device="cpu" + ), + attn_backend=self.backend, + ) + else: # ForwardMode.DECODE + decode_len = q_len # Assuming 1 for decode testing + total_len = self.seq_len + decode_len + if mode == ForwardMode.DECODE and page_size > 1: + # Get next page_size multiple of self.seq_len + out_cache_start = ( + self.batch_size * self.seq_len // page_size + 1 + ) * page_size + # out_cache_end is the start of the next block + out_cache_end = out_cache_start + decode_len * page_size + else: + out_cache_start = self.batch_size * self.seq_len + out_cache_end = self.batch_size * total_len - # Verify output - expected_shape = ( - self.batch_size * self.seq_len, - self.num_heads * self.head_dim, - ) - self._verify_output(output, expected_shape) + forward_batch = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint( + 0, 100, (self.batch_size, decode_len), device=self.device + ), + out_cache_loc=torch.tensor( + [out_cache_start, out_cache_end], device=self.device + ), + seq_lens_sum=self.batch_size * total_len, + forward_mode=mode, + req_pool_indices=torch.arange(self.batch_size, device=self.device), + seq_lens=torch.tensor( + [total_len] * self.batch_size, device=self.device + ), + seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"), + attn_backend=self.backend, 
+ ) - def test_forward_decode(self): - """Test the decode operation with cached tokens.""" - # For decode, we only have one token per sequence - decode_len = 1 - curr_seq_len = self.seq_len + decode_len + # Add token pool + forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool - # Create test inputs - q, k, v = self._create_qkv_tensors(self.batch_size * decode_len) + # Write current batch's req_to_token to req_to_token_pool + self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size) + # Add kv pool for this forward batch + forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool - # Create attention layer - layer = self._create_attention_layer() + return forward_batch - # Create forward batch - forward_batch = ForwardBatch( - batch_size=self.batch_size, - input_ids=torch.randint( - 0, 100, (self.batch_size, decode_len), device=self.device - ), - out_cache_loc=torch.arange( - self.batch_size * self.seq_len, - self.batch_size * curr_seq_len, - device=self.device, - ), - seq_lens_sum=self.batch_size * curr_seq_len, - forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(self.batch_size, device=self.device), - seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device), - attn_backend=self.backend, - ) - - # Add token pool and KV cache - forward_batch.req_to_token_pool = MockReqToTokenPool( - self.batch_size, curr_seq_len, self.device - ) - forward_batch.token_to_kv_pool = self._create_kv_pool( - self.batch_size * curr_seq_len - ) - - # Pre-fill KV cache - cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - torch.arange(self.batch_size * self.seq_len, device=self.device), - cache_k, - cache_v, - layer.k_scale, - layer.v_scale, - ) - - # Initialize forward metadata before running the attention - self.backend.init_forward_metadata(forward_batch) - - # Run forward_decode - output = self.backend.forward_decode(q, k, v, layer, forward_batch) - - # Verify output - expected_shape = (self.batch_size, self.num_heads * self.head_dim) - self._verify_output(output, expected_shape) - - def test_forward_extend_with_prefix(self): - """Test extending from cached prefix tokens.""" - # Define prefix and extend lengths - prefix_len = 2 - extend_len = 2 - total_len = prefix_len + extend_len - - # Create test inputs for the extend portion - q, k, v = self._create_qkv_tensors(self.batch_size * extend_len) - - # Create attention layer - layer = self._create_attention_layer() - - # Create forward batch - forward_batch = ForwardBatch( - batch_size=self.batch_size, - input_ids=torch.randint( - 0, 100, (self.batch_size, extend_len), device=self.device - ), - out_cache_loc=torch.arange( - self.batch_size * prefix_len, - self.batch_size * total_len, - device=self.device, - ), - seq_lens_sum=self.batch_size * total_len, - forward_mode=ForwardMode.EXTEND, - req_pool_indices=torch.arange(self.batch_size, device=self.device), - seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device), - extend_prefix_lens=torch.tensor( - [prefix_len] * self.batch_size, device=self.device - ), - extend_seq_lens=torch.tensor( - [extend_len] * self.batch_size, device=self.device - ), - attn_backend=self.backend, - ) - - # Add token pool and KV cache - forward_batch.req_to_token_pool = MockReqToTokenPool( - self.batch_size, total_len, self.device - ) - forward_batch.token_to_kv_pool = self._create_kv_pool( - self.batch_size * total_len - ) - - # Pre-fill the KV cache for 
prefix with known values + def _setup_kv_cache(self, forward_batch, layer, cache_len): + # Create constant values for the prefix cache for easy debugging cache_k = torch.ones( - self.batch_size * prefix_len, + self.batch_size * cache_len, self.num_heads, self.head_dim, dtype=self.dtype, @@ -278,7 +256,7 @@ class TestFlashAttentionBackend(CustomTestCase): ) cache_v = ( torch.ones( - self.batch_size * prefix_len, + self.batch_size * cache_len, self.num_heads, self.head_dim, dtype=self.dtype, @@ -290,22 +268,82 @@ class TestFlashAttentionBackend(CustomTestCase): # Set the prefix KV cache forward_batch.token_to_kv_pool.set_kv_buffer( layer, - torch.arange(self.batch_size * prefix_len, device=self.device), + torch.arange(self.batch_size * cache_len, device=self.device), cache_k, cache_v, layer.k_scale, layer.v_scale, ) - # Initialize forward metadata before running the attention + def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1): + """ + Run an attention test with the specified parameters. + Args: + mode: ForwardMode.EXTEND or ForwardMode.DECODE + q_len: Length of the query sequence. For decode mode, q_len is 1. + prefix_len: Length of the prefix sequence for extend mode + page_size: Page size for the KV cache + """ + layer = self._create_attention_layer() + + # Create forward batch and set up + forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size) + + # Create QKV tensors for the input + q, k, v = self._create_qkv_tensors(self.batch_size * q_len) + + # KV cache for prefixed extend is prefix_len + # KV cache for decode is same as seq_len + # No KV cache for extend without prefix + if mode == ForwardMode.EXTEND: + if prefix_len > 0: + self._setup_kv_cache(forward_batch, layer, prefix_len) + else: + self._setup_kv_cache(forward_batch, layer, self.seq_len) + self.backend.init_forward_metadata(forward_batch) - # Run forward_extend - output = self.backend.forward_extend(q, k, v, layer, forward_batch) + if mode == ForwardMode.EXTEND: + expected_shape = ( + self.batch_size * q_len, + self.num_heads * self.head_dim, + ) + output = self.backend.forward_extend(q, k, v, layer, forward_batch) + else: + expected_shape = (self.batch_size, self.num_heads * self.head_dim) + output = self.backend.forward_decode(q, k, v, layer, forward_batch) - # Verify output - expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim) - self._verify_output(output, expected_shape) + output_ref = self._run_reference_forward( + mode, q, k, v, layer, forward_batch, expected_shape + ) + + self._verify_output(output, expected_shape, output_ref) + + return output + + def test_forward_extend(self): + """Test the standard extend operation.""" + self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len) + + def test_forward_decode(self): + """Test the decode operation with cached tokens.""" + self._run_attention_test(ForwardMode.DECODE, q_len=1) + + def test_forward_extend_with_prefix(self): + """Test extending from cached prefix tokens.""" + prefix_len = self.seq_len // 2 + extend_len = self.seq_len - prefix_len + self._run_attention_test( + ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len + ) + + def test_forward_extend_with_page_size_greater_than_1(self): + """Test extending from cached prefix tokens with page size greater than 1.""" + self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64) + + def test_forward_decode_with_page_size_greater_than_1(self): + """Test decode operation with page size greater than 1.""" + 
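+        # With page_size=64 the decode branch of _create_forward_batch aligns out_cache_loc
+        # to the next page boundary, so this exercises the paged KV-cache path.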
self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64) if __name__ == "__main__": diff --git a/python/sglang/test/attention/test_flashattn_mla_backend.py b/python/sglang/test/attention/test_flashattn_mla_backend.py new file mode 100644 index 000000000..ebfd0b395 --- /dev/null +++ b/python/sglang/test/attention/test_flashattn_mla_backend.py @@ -0,0 +1,285 @@ +import unittest + +import torch + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend +from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase + + +class MockModelRunner: + def __init__( + self, + kv_lora_rank, + qk_rope_head_dim, + ): + attention_arch = AttentionArch.MLA + self.device = "cuda" + self.dtype = torch.float16 + context_len = 2048 + self.model_config = type( + "ModelConfig", + (), + { + "context_len": context_len, + "attention_arch": attention_arch, + }, + ) + self.sliding_window_size = None + + batch_size = 160 + # Create a proper req_to_token_pool with the req_to_token attribute + self.req_to_token_pool = type( + "TokenPool", + (), + { + # A typical max_bs * max_context_len for cuda graph decode + "size": batch_size, + # Add req_to_token attribute + "req_to_token": torch.zeros( + batch_size, context_len, dtype=torch.int32, device=self.device + ), + }, + ) + self.page_size = 1 + max_total_num_tokens = batch_size * context_len + self.token_to_kv_pool = MLATokenToKVPool( + size=max_total_num_tokens, + page_size=self.page_size, + dtype=self.dtype, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + layer_num=1, # only consider layer=1 for unit test + device=self.device, + enable_memory_saver=False, + ) + + +class MockReqToTokenPool: + def __init__(self, batch_size, seq_len, device): + self.req_to_token = ( + torch.arange(batch_size * seq_len, device=device) + .reshape(batch_size, seq_len) + .to(torch.int32) + ) + + +@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") +class TestFlashAttentionMLABackend(CustomTestCase): + def setUp(self): + # Test parameters + self.batch_size = 2 + self.seq_len = 360 + self.num_heads = 2 + self.device = "cuda" + self.dtype = torch.float16 + self.kv_lora_rank = 512 + self.q_lora_rank = 128 + self.qk_rope_head_dim = 64 + self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank + # Assume no rope scaling + self.scaling = self.qk_head_dim**-0.5 + # Initialize model runner and backend + self._init_model_runner() + self.backend = FlashAttentionBackend(self.model_runner) + self.num_local_heads = 2 + + def _init_model_runner(self): + self.model_runner = MockModelRunner( + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + ) + self.backend = FlashAttentionBackend(self.model_runner) + + def _create_attention_layer(self): + """Create attention layer for testing.""" + self.attn_mqa = RadixAttention( + num_heads=self.num_local_heads, + head_dim=self.kv_lora_rank + self.qk_rope_head_dim, + scaling=self.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=self.kv_lora_rank, + prefix="attn_mqa", + ) + return self.attn_mqa + + def _run_reference_forward( + self, mode, q, k, v, layer, forward_batch, expected_shape + ): + """Run reference forward pass using 
native backend.""" + if mode == ForwardMode.EXTEND: + output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch) + else: # ForwardMode.DECODE + output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch) + return output.view(expected_shape) + + def _verify_output(self, output, expected_shape): + """Verify output tensor shape, dtype, and values.""" + self.assertEqual( + output.shape, + expected_shape, + f"Expected shape {expected_shape}, got {output.shape}", + ) + self.assertEqual(output.dtype, self.dtype) + self.assertEqual(output.device.type, "cuda") + self.assertEqual( + torch.isnan(output).sum().item(), 0, "Output contains NaN values" + ) + + def _create_forward_batch(self, mode, q_len=None, prefix_len=0): + """Create a forward batch for testing based on mode and lengths.""" + # Default to self.seq_len if not specified + q_len = q_len or self.seq_len + + if mode == ForwardMode.EXTEND: + total_len = prefix_len + q_len + out_cache_start = prefix_len * self.batch_size + out_cache_end = total_len * self.batch_size + + forward_batch = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint( + 0, 100, (self.batch_size, q_len), device=self.device + ), + out_cache_loc=torch.arange( + out_cache_start, out_cache_end, device=self.device + ), + seq_lens_sum=self.batch_size * total_len, + forward_mode=mode, + req_pool_indices=torch.arange(self.batch_size, device=self.device), + seq_lens=torch.tensor( + [total_len] * self.batch_size, device=self.device + ), + seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"), + extend_prefix_lens=torch.tensor( + [prefix_len] * self.batch_size, device=self.device + ), + extend_prefix_lens_cpu=torch.tensor( + [prefix_len] * self.batch_size, device="cpu" + ), + extend_seq_lens=torch.tensor( + [q_len] * self.batch_size, device=self.device + ), + extend_seq_lens_cpu=torch.tensor( + [q_len] * self.batch_size, device="cpu" + ), + attn_backend=self.backend, + ) + + else: # ForwardMode.DECODE + decode_len = q_len # typically 1 for decode mode + total_len = self.seq_len + decode_len + out_cache_start = self.batch_size * self.seq_len + out_cache_end = self.batch_size * total_len + + forward_batch = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint( + 0, 100, (self.batch_size, decode_len), device=self.device + ), + out_cache_loc=torch.arange( + out_cache_start, out_cache_end, device=self.device + ), + seq_lens_sum=self.batch_size * total_len, + forward_mode=mode, + req_pool_indices=torch.arange(self.batch_size, device=self.device), + seq_lens=torch.tensor( + [total_len] * self.batch_size, device=self.device + ), + seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"), + attn_backend=self.backend, + ) + + # Add token pool from model runner to forward batch + forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool + + # Add KV cache from model runner to forward batch + forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool + + return forward_batch + + def _setup_kv_cache(self, forward_batch, layer, cache_len): + """Set up KV cache with prefix tokens.""" + if cache_len <= 0: + return + + # Create constant values for the prefix cache for easy debugging + latent_cache = torch.ones( + self.batch_size * cache_len, + 1, # latent cache has only one head in MQA + self.kv_lora_rank + self.qk_rope_head_dim, + dtype=self.dtype, + device=self.device, + ) + + # Set the prefix KV cache + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + torch.arange(self.batch_size 
* cache_len, device=self.device), + latent_cache, + None, + ) + + def _run_attention_test(self, mode, q_len, prefix_len=0): + """ + Run an attention test with the specified parameters. + Args: + mode: ForwardMode.EXTEND or ForwardMode.DECODE + q_len: Length of the query sequence. For decode mode, q_len is 1. + prefix_len: Length of the prefix sequence for extend mode + """ + layer = self._create_attention_layer() + + # Create forward batch and set up + forward_batch = self._create_forward_batch(mode, q_len, prefix_len) + + # Create q, kv_compressed for testing + q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim) + kv_shape = (self.batch_size * q_len, self.qk_head_dim) + q = torch.randn(q_shape, dtype=self.dtype, device=self.device) + kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device) + # v is not used for mqa, all values passed in through k + k = kv_compressed.unsqueeze(1) + v = torch.randn((1), dtype=self.dtype, device=self.device) + + self._setup_kv_cache(forward_batch, layer, prefix_len) + + self.backend.init_forward_metadata(forward_batch) + + expected_shape = ( + self.batch_size * q_len, + self.num_heads * self.kv_lora_rank, + ) + + if mode == ForwardMode.EXTEND: + output = self.backend.forward_extend(q, k, v, layer, forward_batch) + else: + output = self.backend.forward_decode(q, k, v, layer, forward_batch) + + self._verify_output(output, expected_shape) + return output + + def test_forward_extend(self): + """Test the standard extend operation.""" + self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len) + + def test_forward_decode(self): + """Test the decode operation with cached tokens.""" + self._run_attention_test(ForwardMode.DECODE, q_len=1) + + def test_forward_extend_with_prefix(self): + """Test extending from cached prefix tokens.""" + prefix_len = self.seq_len // 2 + extend_len = self.seq_len - prefix_len + self._run_attention_test( + ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 84e6ff5ea..c2de835b6 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -28,6 +28,7 @@ suites = { TestFile("test_chunked_prefill.py", 336), TestFile("test_eagle_infer.py", 500), TestFile("test_ebnf_constrained.py"), + TestFile("test_fa3.py", 5), TestFile("test_fp8_kernel.py", 8), TestFile("test_embedding_openai_server.py", 36), TestFile("test_hidden_states.py", 55), diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py new file mode 100644 index 000000000..68b6432fc --- /dev/null +++ b/test/srt/test_fa3.py @@ -0,0 +1,180 @@ +import unittest +from types import SimpleNamespace + +import requests +import torch + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + +""" +Integration test for python/sglang/srt/layers/attention/flashattention_backend.py +""" +# Change to your own model if testing model is not public. +MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST +MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST +# Setting data path to None uses default data path in few_shot_gsm8k eval test. 
+DATA_PATH = None + + +@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") +class BaseFlashAttentionTest(unittest.TestCase): + """Base class for FlashAttention tests to reduce code duplication.""" + + model = MODEL_USED_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + accuracy_threshold = 0.62 + + @classmethod + def get_server_args(cls): + """Return the arguments for the server launch. Override in subclasses.""" + args = [ + "--trust-remote-code", + "--enable-torch-compile", + "--attention-backend", + "fa3", + ] + return args + + @classmethod + def setUpClass(cls): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.get_server_args(), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + data_path=DATA_PATH, + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + # Use the appropriate metric key based on the test class + metric_key = "accuracy" + self.assertGreater(metrics[metric_key], self.accuracy_threshold) + + +class TestFlashAttention3(BaseFlashAttentionTest): + """Test FlashAttention3 with MLA model and CUDA graph enabled.""" + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + ] + ) + return args + + +class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest): + """Test FlashAttention3 with CUDA graph disabled.""" + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--disable-cuda-graph", + ] + ) + return args + + +class TestFlashAttention3MLA(BaseFlashAttentionTest): + """Test FlashAttention3 with MLA.""" + + model = MODEL_USED_FOR_TEST_MLA + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + ] + ) + return args + + +class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): + """Test FlashAttention3 with speculative decode enabled.""" + + model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE3", + "--speculative-draft", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "3", + "--dtype", + "float16", + ] + ) + return args + + def test_gsm8k(self): + """ + Override the test_gsm8k to further test for average speculative accept length. + """ + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=DATA_PATH, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 1.5) + + +if __name__ == "__main__": + unittest.main()
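
As a quick local check (a sketch, not part of the patch: it assumes a CUDA GPU and that the `python/` directory is importable, so the new test files resolve as `sglang.test.attention.*`), the added unit tests can be driven with the standard `unittest` loader:

```python
# Sketch only: run the new FA3 attention unit tests directly.
# Assumes a CUDA device; module names follow the file paths added in this patch.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "sglang.test.attention.test_flashattn_backend",
        "sglang.test.attention.test_flashattn_mla_backend",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)
```

The `test/srt/test_fa3.py` integration test is registered in `run_suite.py` above; it additionally requires SM90+ hardware (see its `get_device_sm()` skip condition) and launches a real model server via `popen_launch_server`.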