From 7b68d271119655e993232b4785a9cec26e0180ec Mon Sep 17 00:00:00 2001
From: Xiaoze Fan
Date: Mon, 21 Jul 2025 22:06:15 +0800
Subject: [PATCH] [Feature] Add a test for Layer-wise Prefill (#8231)

Signed-off-by: jason-fxz
---
 test/srt/test_forward_split_prefill.py | 299 +++++++++++++++++++++++++
 1 file changed, 299 insertions(+)
 create mode 100644 test/srt/test_forward_split_prefill.py

diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py
new file mode 100644
index 000000000..bbd247583
--- /dev/null
+++ b/test/srt/test_forward_split_prefill.py
@@ -0,0 +1,299 @@
+"""
+Test forward_split_prefill functionality.
+
+Usage:
+python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill
+or
+python3 test_forward_split_prefill.py
+"""
+
+import time
+import unittest
+
+import numpy as np
+import torch
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
+
+
+class TestForwardSplitPrefill(CustomTestCase):
+    """Test cases for forward_split_prefill functionality."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Set up the test environment once for all tests."""
+        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.tp_size = 1
+        cls.device = "cuda"
+
+        # Initialize server args
+        cls.server_args = ServerArgs(
+            model_path=cls.model_path,
+            tokenizer_path=cls.model_path,
+            host="127.0.0.1",
+            disable_cuda_graph=True,  # Disable CUDA graph for testing split prefill
+            disable_hybrid_swa_memory=True,
+            port=30000,
+            tp_size=cls.tp_size,
+            mem_fraction_static=0.8,
+            trust_remote_code=True,
+        )
+
+        cls.port_args = PortArgs.init_new(cls.server_args)
+
+        # Load model and tokenizer
+        cls.model_config = ModelConfig.from_server_args(cls.server_args)
+        cls.model_runner = ModelRunner(
+            model_config=cls.model_config,
+            mem_fraction_static=cls.server_args.mem_fraction_static,
+            gpu_id=0,
+            tp_rank=0,
+            tp_size=cls.tp_size,
+            pp_rank=0,
+            pp_size=1,
+            nccl_port=cls.port_args.nccl_port,
+            server_args=cls.server_args,
+        )
+
+        cls.tokenizer = get_tokenizer(
+            cls.server_args.tokenizer_path,
+            tokenizer_mode=cls.server_args.tokenizer_mode,
+            trust_remote_code=cls.server_args.trust_remote_code,
+        )
+
+        print(
+            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
+        )
+
+    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
+        """Prepare a test batch for split prefill testing."""
+        # Create synthetic input
+        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)
+
+        sampling_params = SamplingParams(
+            temperature=0.0,
+            max_new_tokens=8,
+        )
+
+        reqs = []
+        for i in range(batch_size):
+            req = Req(
+                rid=i,
+                origin_input_text="",
+                origin_input_ids=list(input_ids[i]),
+                sampling_params=sampling_params,
+            )
+            req.prefix_indices = []
+            req.fill_ids = req.origin_input_ids
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+            reqs.append(req)
+
+        batch = ScheduleBatch.init_new(
+            reqs=reqs,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+            tree_cache=None,
+            model_config=self.model_config,
+            enable_overlap=False,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            enable_custom_logit_processor=False,
+        )
+        if is_split_prefill:
+            batch.prepare_for_split_prefill()
+        else:
+            batch.prepare_for_extend()
+
+        # Create forward batch
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+
+        return forward_batch
+
+    def test_split_prefill_functionality(self):
+        """Test that split prefill can complete successfully."""
+        print("\n=== Testing split prefill functionality ===")
+
+        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)
+
+        # Reset split index
+        forward_batch.split_index = 0
+
+        # Test split prefill in chunks
+        num_layers = self.model_config.num_hidden_layers
+        chunk_size = max(1, num_layers // 4)  # Split into 4 chunks
+
+        results = []
+        split_count = 0
+
+        while forward_batch.split_index < num_layers:
+            print(
+                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
+            )
+
+            result = self.model_runner.forward_split_prefill(
+                forward_batch=forward_batch,
+                reinit_attn_backend=(split_count == 0),
+                forward_count=chunk_size,
+            )
+
+            results.append(result)
+            split_count += 1
+
+            # Verify split_index is updated correctly
+            expected_next_index = min(split_count * chunk_size, num_layers)
+            self.assertEqual(forward_batch.split_index, expected_next_index)
+
+        # The last result should contain logits
+        self.assertIsNotNone(results[-1], "Final split should return logits")
+        print(f"Split prefill completed in {split_count} splits")
+
+    def test_split_prefill_vs_normal_prefill(self):
+        """Test that split prefill produces the same results as normal prefill."""
+        print("\n=== Testing split prefill vs normal prefill consistency ===")
+
+        forward_batch_normal = self.prepare_test_batch(
+            batch_size=2, input_len=128, is_split_prefill=False
+        )
+        forward_batch_split = self.prepare_test_batch(
+            batch_size=2, input_len=128, is_split_prefill=True
+        )
+
+        # Ensure same input
+        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
+        forward_batch_split.positions = forward_batch_normal.positions.clone()
+
+        # Method 1: Normal extend (prefill)
+        print("Running normal extend (prefill)...")
+        normal_result = self.model_runner.forward_extend(forward_batch_normal)
+
+        # Method 2: Split prefill
+        print("Running split prefill...")
+        num_layers = self.model_config.num_hidden_layers
+        chunk_size = max(1, num_layers // 3)  # Split into 3 chunks
+
+        split_result = None
+
+        while forward_batch_split.split_index < num_layers:
+            result = self.model_runner.forward_split_prefill(
+                forward_batch=forward_batch_split,
+                forward_count=chunk_size,
+            )
+            if result is not None:
+                split_result = result
+
+        # Compare results
+        self.assertIsNotNone(normal_result, "Normal prefill should return result")
+        self.assertIsNotNone(split_result, "Split prefill should return result")
+
+        # Compare logits shapes
+        self.assertEqual(
+            normal_result.next_token_logits.shape,
+            split_result.next_token_logits.shape,
+            "Logits shapes should match",
+        )
+
+        # Compare logits values (should be very close due to same computation)
+        # Use a larger tolerance for numerical differences in split computation
+        torch.testing.assert_close(
+            normal_result.next_token_logits,
+            split_result.next_token_logits,
+            rtol=1e-3,
+            atol=1e-3,
+            msg="Split prefill and normal prefill should produce similar logits",
+        )
+
+        print("✓ Split prefill and normal prefill produce consistent results")
+
+    def test_split_prefill_different_chunk_sizes(self):
+        """Test split prefill with different chunk sizes."""
+        print("\n=== Testing split prefill with different chunk sizes ===")
+
+        num_layers = self.model_config.num_hidden_layers
+        chunk_sizes = [1, 2, max(1, num_layers // 2), num_layers]
+
+        # Prepare identical batches for each test
+        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
+        base_input_ids = base_batch.input_ids.clone()
+        base_positions = base_batch.positions.clone()
+
+        results = []
+
+        for chunk_size in chunk_sizes:
+            if chunk_size > num_layers:
+                continue
+
+            print(f"Testing chunk size: {chunk_size}")
+
+            # Prepare fresh batch
+            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
+            forward_batch.input_ids = base_input_ids.clone()
+            forward_batch.positions = base_positions.clone()
+            forward_batch.split_index = 0
+
+            # Run split prefill
+            split_result = None
+
+            while forward_batch.split_index < num_layers:
+                result = self.model_runner.forward_split_prefill(
+                    forward_batch=forward_batch,
+                    forward_count=chunk_size,
+                )
+                if result is not None:
+                    split_result = result
+
+            self.assertIsNotNone(
+                split_result,
+                f"Split prefill should succeed with chunk_size={chunk_size}",
+            )
+            results.append(split_result)
+
+        # All results should be identical (same input, same computation)
+        if len(results) > 1:
+            for i, result in enumerate(results[1:], 1):
+                torch.testing.assert_close(
+                    results[0].next_token_logits,
+                    result.next_token_logits,
+                    rtol=1e-3,
+                    atol=1e-3,
+                    msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_sizes[i]})",
+                )
+
+        print("✓ All chunk sizes produce consistent results")
+
+    def test_split_prefill_edge_cases(self):
+        """Test edge cases for split prefill."""
+        print("\n=== Testing split prefill edge cases ===")
+
+        # Test with single layer chunks
+        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)
+
+        # Process one layer at a time
+        num_layers = self.model_config.num_hidden_layers
+        for layer_idx in range(num_layers):
+            result = self.model_runner.forward_split_prefill(
+                forward_batch=forward_batch,
+                reinit_attn_backend=(layer_idx == 0),
+                forward_count=1,  # One layer at a time
+            )
+
+            if layer_idx == num_layers - 1:
+                # Last layer should return result
+                self.assertIsNotNone(result, "Last layer should return logits")
+            else:
+                # Intermediate layers should return None
+                self.assertIsNone(result, f"Layer {layer_idx} should return None")
+
+        print("✓ Single layer processing works correctly")
+
+
+if __name__ == "__main__":
+    unittest.main()
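
Taken together, the tests pin down the contract of forward_split_prefill: each call runs forward_count layers, advances forward_batch.split_index, and returns logits only once the final layer has run. Below is a minimal sketch of that driver loop (not part of the patch; it assumes a model_runner, model_config, and a forward_batch prepared as in prepare_test_batch() above, using only names that appear in the test):

    num_layers = model_config.num_hidden_layers
    chunk_size = max(1, num_layers // 4)  # four roughly equal layer chunks
    logits_output = None
    while forward_batch.split_index < num_layers:
        # Intermediate chunks return None; the chunk that finishes the
        # last layer returns the logits for the whole batch.
        result = model_runner.forward_split_prefill(
            forward_batch=forward_batch,
            reinit_attn_backend=(forward_batch.split_index == 0),
            forward_count=chunk_size,
        )
        if result is not None:
            logits_output = result
    assert logits_output is not None, "final chunk should return logits"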