[Feature] Add a test for Layer-wise Prefill (#8231)
Signed-off-by: jason-fxz <jason341132@qq.com>
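
Adds test/srt/test_forward_split_prefill.py, covering ModelRunner's
forward_split_prefill: the test verifies that layer-wise prefill runs to
completion, that its logits match a normal forward_extend pass within
tolerance, that different layer chunk sizes produce consistent results, and
that single-layer stepping returns logits only on the last layer.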
 test/srt/test_forward_split_prefill.py | 299 ++++++++++++++++++++++++++++++++
 1 file changed, 299 insertions(+) (new file)
@@ -0,0 +1,299 @@
"""
Test forward_split_prefill functionality.

Usage:
python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill
or
python3 test_forward_split_prefill.py
"""

import time
import unittest

import numpy as np
import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
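
# Interface contract exercised below (inferred from the tests themselves, not
# from a spec): ModelRunner.forward_split_prefill() runs at most
# `forward_count` decoder layers per call, advances forward_batch.split_index
# by the number of layers actually executed, and returns None until the final
# layer plus the logits head have run, at which point it returns the logits
# output. The driver pattern used throughout is:
#
#     while forward_batch.split_index < num_layers:
#         result = model_runner.forward_split_prefill(
#             forward_batch=forward_batch, forward_count=chunk_size
#         )
#     # the last iteration's `result` carries next_token_logits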


class TestForwardSplitPrefill(CustomTestCase):
    """Test cases for forward_split_prefill functionality."""

    @classmethod
    def setUpClass(cls):
        """Set up the test environment once for all tests."""
        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.tp_size = 1
        cls.device = "cuda"

        # Initialize server args
        cls.server_args = ServerArgs(
            model_path=cls.model_path,
            tokenizer_path=cls.model_path,
            host="127.0.0.1",
            disable_cuda_graph=True,  # Disable CUDA graph for testing split prefill
            disable_hybrid_swa_memory=True,
            port=30000,
            tp_size=cls.tp_size,
            mem_fraction_static=0.8,
            trust_remote_code=True,
        )

        cls.port_args = PortArgs.init_new(cls.server_args)

        # Load model and tokenizer
        cls.model_config = ModelConfig.from_server_args(cls.server_args)
        cls.model_runner = ModelRunner(
            model_config=cls.model_config,
            mem_fraction_static=cls.server_args.mem_fraction_static,
            gpu_id=0,
            tp_rank=0,
            tp_size=cls.tp_size,
            pp_rank=0,
            pp_size=1,
            nccl_port=cls.port_args.nccl_port,
            server_args=cls.server_args,
        )

        cls.tokenizer = get_tokenizer(
            cls.server_args.tokenizer_path,
            tokenizer_mode=cls.server_args.tokenizer_mode,
            trust_remote_code=cls.server_args.trust_remote_code,
        )

        print(
            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
        )

    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
        """Prepare a test batch for split prefill testing."""
        # Create synthetic input
        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)

        sampling_params = SamplingParams(
            temperature=0.0,
            max_new_tokens=8,
        )

        reqs = []
        for i in range(batch_size):
            req = Req(
                rid=i,
                origin_input_text="",
                origin_input_ids=list(input_ids[i]),
                sampling_params=sampling_params,
            )
            req.prefix_indices = []
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            req.logprob_start_len = len(req.origin_input_ids) - 1
            reqs.append(req)

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=None,
            model_config=self.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            enable_custom_logit_processor=False,
        )
        if is_split_prefill:
            batch.prepare_for_split_prefill()
        else:
            batch.prepare_for_extend()

        # Create forward batch
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

        return forward_batch
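
    # Note (assumption, inferred from usage rather than documented behavior):
    # prepare_for_split_prefill() is the split-prefill counterpart of
    # prepare_for_extend(); both allocate KV-cache slots and build the batch
    # metadata, while the split variant additionally marks the batch for
    # layer-wise execution and starts split_index at 0.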

    def test_split_prefill_functionality(self):
        """Test that split prefill can complete successfully."""
        print("\n=== Testing split prefill functionality ===")

        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)

        # Reset split index
        forward_batch.split_index = 0

        # Test split prefill in chunks
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 4)  # Split into ~4 chunks

        results = []
        split_count = 0

        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )

            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )

            results.append(result)
            split_count += 1

            # Verify split_index is updated correctly
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        # The last result should contain logits
        self.assertIsNotNone(results[-1], "Final split should return logits")
        print(f"Split prefill completed in {split_count} splits")
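
    # Worked example of the split_index bookkeeping asserted above: with
    # num_layers = 24 and chunk_size = 24 // 4 = 6, split_index advances
    # 6 -> 12 -> 18 -> 24 over four calls; the min(...) cap matters when
    # num_layers is not an exact multiple of chunk_size (e.g. 26 layers with
    # chunk_size 6 ends at min(5 * 6, 26) = 26 on the fifth call).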

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # Ensure same input
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        # Method 1: Normal extend (prefill)
        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        # Method 2: Split prefill
        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 3)  # Split into ~3 chunks

        split_result = None

        while forward_batch_split.split_index < num_layers:
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch_split,
                forward_count=chunk_size,
            )
            if result is not None:
                split_result = result

        # Compare results
        self.assertIsNotNone(normal_result, "Normal prefill should return a result")
        self.assertIsNotNone(split_result, "Split prefill should return a result")

        # Compare logits shapes
        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare logits values (should be very close since the computation is
        # the same); use a larger tolerance for numerical differences in split
        # computation
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )

        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        chunk_sizes = [1, 2, max(1, num_layers // 2), num_layers]

        # Prepare identical batches for each test
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        results = []

        for chunk_size in chunk_sizes:
            if chunk_size > num_layers:
                continue

            print(f"Testing chunk size: {chunk_size}")

            # Prepare fresh batch
            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()
            forward_batch.split_index = 0

            # Run split prefill
            split_result = None

            while forward_batch.split_index < num_layers:
                result = self.model_runner.forward_split_prefill(
                    forward_batch=forward_batch,
                    forward_count=chunk_size,
                )
                if result is not None:
                    split_result = result

            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            results.append(split_result)

        # Results from all chunk sizes should match (same input, same computation)
        if len(results) > 1:
            for i, result in enumerate(results[1:], 1):
                torch.testing.assert_close(
                    results[0].next_token_logits,
                    result.next_token_logits,
                    rtol=1e-3,
                    atol=1e-3,
                    msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_sizes[i]})",
                )

        print("✓ All chunk sizes produce consistent results")
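
    # Different chunkings should compute the same math layer by layer, but
    # kernel-level non-determinism and accumulation-order differences can still
    # produce tiny discrepancies; hence the rtol/atol = 1e-3 tolerance above
    # rather than exact equality.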

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        # Test with single-layer chunks
        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Process one layer at a time
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,  # One layer at a time
            )

            if layer_idx == num_layers - 1:
                # Last layer should return result
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate layers should return None
                self.assertIsNone(result, f"Layer {layer_idx} should return None")

        print("✓ Single layer processing works correctly")
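

# Running this file directly requires a CUDA-capable GPU: setUpClass hard-codes
# device "cuda" and loads the model through ModelRunner, so there is no
# CPU-only fallback.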
if __name__ == "__main__":
    unittest.main()