Integrate trtllm ragged attention for prefill self-attention (#9801)
This commit is contained in:
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
|
||||
"v_head_dim": 512,
|
||||
"num_kv_heads": 1,
|
||||
"layer_id": 0,
|
||||
"tp_q_head_num": 128,
|
||||
"tp_k_head_num": 128,
|
||||
"prefill_head_dim": 192,
|
||||
"prefill_v_head_dim": 128,
|
||||
}
|
||||
|
||||
ROPE_BASE = 10000
|
||||
@@ -92,7 +96,7 @@ TEST_CASES = {
|
||||
"description": "Medium-scale batch",
|
||||
},
|
||||
],
|
||||
"decode_output_match": [
|
||||
"output_match": [
|
||||
{
|
||||
"name": "single_fp16",
|
||||
"batch_size": 1,
|
||||
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
config.update(test_case)
|
||||
return config
|
||||
|
||||
def _create_model_components(self, config):
|
||||
def _create_model_components(self, config, is_prefill=False):
|
||||
"""Create model runners, backends, and layer for testing."""
|
||||
# Create model runners
|
||||
model_runner_trtllm = MockModelRunner(config)
|
||||
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
||||
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
||||
|
||||
head_dim = (
|
||||
config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
||||
if not is_prefill
|
||||
else config["prefill_head_dim"]
|
||||
)
|
||||
v_head_dim = (
|
||||
config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
|
||||
)
|
||||
|
||||
# Create RadixAttention layer
|
||||
layer = RadixAttention(
|
||||
num_heads=config["num_attention_heads"],
|
||||
head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
|
||||
head_dim=head_dim,
|
||||
scaling=model_runner_trtllm.model_config.scaling,
|
||||
num_kv_heads=config["num_kv_heads"],
|
||||
layer_id=config["layer_id"],
|
||||
v_head_dim=config["v_head_dim"],
|
||||
v_head_dim=v_head_dim,
|
||||
prefix="attn_mqa",
|
||||
)
|
||||
|
||||
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
||||
print(f"\nRunning decode output matching tests...")
|
||||
|
||||
for test_case in TEST_CASES["decode_output_match"]:
|
||||
for test_case in TEST_CASES["output_match"]:
|
||||
with self.subTest(test_case=test_case["name"]):
|
||||
print(f" Testing {test_case['name']}: {test_case['description']}")
|
||||
|
||||
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
self.assertIsNotNone(metadata_3.block_kv_indices)
|
||||
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
||||
|
||||
def test_prefill_output_match_self_attention(self):
    """Check that TRTLLM MLA prefill output matches the reference backend.

    For the first two ``output_match`` cases, both backends run
    ``forward_extend`` on the same random Q/K/V tensors with zero-length
    prefixes (full-length sequences), and the two outputs are compared
    within the configured tolerance (default 1e-2).
    """

    def _create_forward_batch_prefill(
        batch_size,
        seq_lens,
        extend_prefix_lens,
        backend,
        model_runner,
        config,
    ):
        """Build an EXTEND-mode ForwardBatch wired to ``backend`` and the runner's pools."""
        device = config["device"]
        batch = ForwardBatch(
            batch_size=batch_size,
            input_ids=torch.randint(0, 100, (batch_size, 1), device=device),
            out_cache_loc=torch.arange(batch_size, device=device),
            seq_lens_sum=int(seq_lens.sum().item()),
            extend_prefix_lens=extend_prefix_lens,
            extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
            # Tokens actually being extended = full length minus cached prefix.
            extend_seq_lens_cpu=(seq_lens - extend_prefix_lens).cpu().int().tolist(),
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=torch.arange(batch_size, device=device),
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.cpu(),
            attn_attend_prefix_cache=False,
            mha_return_lse=False,
            attn_backend=backend,
        )
        batch.req_to_token_pool = model_runner.req_to_token_pool
        batch.token_to_kv_pool = model_runner.token_to_kv_pool

        # Position information for RoPE
        batch.positions = torch.arange(batch_size, device=device)
        return batch

    def _create_qkv_tensors_prefill(batch_size, seq_len, config, dtype_override=None):
        """Draw random per-head Q, K, V tensors sized from the prefill config entries."""
        device = config["device"]
        dtype = config["dtype"] if not dtype_override else dtype_override
        num_tokens = batch_size * seq_len

        q_heads = config["tp_q_head_num"]
        kv_heads = config["tp_k_head_num"]
        qk_dim = config["prefill_head_dim"]
        v_dim = config["prefill_v_head_dim"]

        def _draw(num_heads, dim):
            # Sample flat (tokens, heads * dim), then expose the head axis.
            flat = torch.randn(
                (num_tokens, num_heads * dim), dtype=dtype, device=device
            )
            return flat.view(-1, num_heads, dim)

        # Draw order (q, k, v) is kept so the seeded values stay reproducible.
        q = _draw(q_heads, qk_dim)
        k = _draw(kv_heads, qk_dim)
        v = _draw(kv_heads, v_dim)
        return q, k, v

    print("\nRunning prefill output tests...")

    for case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
        with self.subTest(test_case=case["name"]):
            print(f"Prefill Testing {case['name']}: {case['description']}")

            config = self._merge_config(case)
            device = config["device"]
            batch_size = config["batch_size"]
            max_seq_len = config["max_seq_len"]

            # Runners, backends and the attention layer, built in prefill mode.
            (
                runner_trtllm,
                runner_reference,
                trtllm_backend,
                reference_backend,
                layer,
            ) = self._create_model_components(config, is_prefill=True)

            # Prefill uses full sequences: every request has max_seq_len tokens.
            seq_lens = torch.full((batch_size,), max_seq_len, device=device)

            # One forward batch per backend, each with its own zero-prefix tensor.
            fb_trtllm = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=device, dtype=torch.int32),
                trtllm_backend,
                runner_trtllm,
                config,
            )
            fb_reference = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=device, dtype=torch.int32),
                reference_backend,
                runner_reference,
                config,
            )

            # Initialize metadata for both backends
            trtllm_backend.init_forward_metadata(fb_trtllm)
            reference_backend.init_forward_metadata(fb_reference)

            # Seed right before sampling so Q/K/V are reproducible.
            torch.manual_seed(config["seed_qkv"])
            q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)

            # Run prefill on both backends.
            # NOTE(review): the trailing positional False is presumably
            # save_kv_cache — confirm against the backend signature.
            trtllm_raw = trtllm_backend.forward_extend(
                q, k, v, layer, fb_trtllm, False
            )
            out_trtllm = trtllm_raw.view(-1, layer.tp_q_head_num * layer.v_head_dim)
            out_reference = reference_backend.forward_extend(
                q, k, v, layer, fb_reference, False
            )

            tolerance = config.get("tolerance", 1e-2)
            comparison_passed = compare_outputs(
                out_trtllm, out_reference, tolerance=tolerance
            )
            # Computed unconditionally, as the failure message is built eagerly.
            max_abs_diff = (out_trtllm - out_reference).abs().max().item()
            self.assertTrue(
                comparison_passed,
                f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
                f"Config: {case['name']}, "
                f"Max diff: {max_abs_diff}",
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user