Integrate trtllm ragged attention for prefill self-attention (#9801)
This commit is contained in:
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
|
|||||||
def update_wrapper(
|
def update_wrapper(
|
||||||
self,
|
self,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
disable_flashinfer_ragged: bool = False,
|
||||||
):
|
):
|
||||||
assert forward_batch.num_prefix_chunks is not None
|
assert forward_batch.num_prefix_chunks is not None
|
||||||
num_prefix_chunks = forward_batch.num_prefix_chunks
|
num_prefix_chunks = forward_batch.num_prefix_chunks
|
||||||
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
|
|||||||
causal=False,
|
causal=False,
|
||||||
)
|
)
|
||||||
# ragged prefill
|
# ragged prefill
|
||||||
self.ragged_wrapper.begin_forward(
|
if not disable_flashinfer_ragged:
|
||||||
qo_indptr=qo_indptr,
|
self.ragged_wrapper.begin_forward(
|
||||||
kv_indptr=qo_indptr,
|
qo_indptr=qo_indptr,
|
||||||
num_qo_heads=self.num_local_heads,
|
kv_indptr=qo_indptr,
|
||||||
num_kv_heads=self.num_local_heads,
|
num_qo_heads=self.num_local_heads,
|
||||||
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
num_kv_heads=self.num_local_heads,
|
||||||
head_dim_vo=self.v_head_dim,
|
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
|
||||||
q_data_type=self.q_data_type,
|
head_dim_vo=self.v_head_dim,
|
||||||
causal=True,
|
q_data_type=self.q_data_type,
|
||||||
)
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
|
def init_mha_chunk_metadata(
|
||||||
|
self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
|
||||||
|
):
|
||||||
"""Init the metadata for a forward pass."""
|
"""Init the metadata for a forward pass."""
|
||||||
self.mha_chunk_kv_cache.update_wrapper(forward_batch)
|
self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128
|
|||||||
global_zero_init_workspace_buffer = None
|
global_zero_init_workspace_buffer = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TRTLLMMLAPrefillMetadata:
    """Metadata for TRTLLM MLA prefill operations.

    Populated by ``TRTLLMMLABackend.init_forward_metadata`` for plain extend
    batches (not target-verify, not draft-extend) and consumed by
    ``TRTLLMMLABackend.forward_extend`` when it calls the ragged prefill
    kernel.

    The instance is constructed positionally, so the field order below
    (``max_seq_len``, ``cum_seq_lens``, ``seq_lens``) must not change.
    """

    # Largest entry of ``forward_batch.extend_seq_lens_cpu``; forwarded to the
    # kernel as both ``max_q_len`` and ``max_kv_len``.
    max_seq_len: int
    # A leading 0 followed by the running sum of ``seq_lens``, cast to int32;
    # forwarded as both ``cum_seq_lens_q`` and ``cum_seq_lens_kv``.
    cum_seq_lens: torch.Tensor
    # ``forward_batch.seq_lens - forward_batch.extend_prefix_lens``, i.e. the
    # per-request length with the cached prefix subtracted.
    seq_lens: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TRTLLMMLADecodeMetadata:
|
class TRTLLMMLADecodeMetadata:
|
||||||
"""Metadata for TRTLLM MLA decode operations."""
|
"""Metadata for TRTLLM MLA decode operations."""
|
||||||
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
# CUDA graph state
|
# CUDA graph state
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.decode_cuda_graph_kv_indices = None
|
self.decode_cuda_graph_kv_indices = None
|
||||||
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
|
||||||
|
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
||||||
|
|
||||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
max_seq_len_val,
|
max_seq_len_val,
|
||||||
)
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
self.forward_metadata = metadata
|
self.forward_decode_metadata = metadata
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize the metadata for a forward pass."""
|
"""Initialize the metadata for a forward pass."""
|
||||||
# Delegate to parent for non-decode modes.
|
# Delegate to parent for non-decode modes.
|
||||||
if not forward_batch.forward_mode.is_decode_or_idle():
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
|
||||||
|
cum_seq_lens_q = torch.cat(
|
||||||
|
(
|
||||||
|
torch.tensor([0], device=forward_batch.seq_lens.device),
|
||||||
|
torch.cumsum(seq_lens, dim=0),
|
||||||
|
)
|
||||||
|
).int()
|
||||||
|
max_seq_len = max(forward_batch.extend_seq_lens_cpu)
|
||||||
|
self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
|
||||||
|
max_seq_len,
|
||||||
|
cum_seq_lens_q,
|
||||||
|
seq_lens,
|
||||||
|
)
|
||||||
|
elif forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
|
||||||
|
# Get maximum sequence length.
|
||||||
|
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
|
||||||
|
max_seq = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
else:
|
||||||
|
max_seq = forward_batch.seq_lens.max().item()
|
||||||
|
|
||||||
|
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
||||||
|
block_kv_indices = self._create_block_kv_indices(
|
||||||
|
bs,
|
||||||
|
max_seqlen_pad,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seq_len_val = int(max_seq)
|
||||||
|
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
||||||
|
self.workspace_buffer, block_kv_indices, max_seq_len_val
|
||||||
|
)
|
||||||
|
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
|
||||||
|
else:
|
||||||
return super().init_forward_metadata(forward_batch)
|
return super().init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
bs = forward_batch.batch_size
|
def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)
|
||||||
# Get maximum sequence length.
|
|
||||||
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
|
|
||||||
max_seq = forward_batch.seq_lens_cpu.max().item()
|
|
||||||
else:
|
|
||||||
max_seq = forward_batch.seq_lens.max().item()
|
|
||||||
|
|
||||||
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
|
||||||
block_kv_indices = self._create_block_kv_indices(
|
|
||||||
bs,
|
|
||||||
max_seqlen_pad,
|
|
||||||
forward_batch.req_pool_indices,
|
|
||||||
forward_batch.seq_lens,
|
|
||||||
forward_batch.seq_lens.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
max_seq_len_val = int(max_seq)
|
|
||||||
self.forward_metadata = TRTLLMMLADecodeMetadata(
|
|
||||||
self.workspace_buffer, block_kv_indices, max_seq_len_val
|
|
||||||
)
|
|
||||||
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
|
||||||
|
|
||||||
def quantize_and_rope_for_fp8(
|
def quantize_and_rope_for_fp8(
|
||||||
self,
|
self,
|
||||||
@@ -459,7 +490,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
# Get metadata
|
# Get metadata
|
||||||
metadata = (
|
metadata = (
|
||||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||||
or self.forward_metadata
|
or self.forward_decode_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
# Scale computation for TRTLLM MLA kernel BMM1 operation:
|
# Scale computation for TRTLLM MLA kernel BMM1 operation:
|
||||||
@@ -496,6 +527,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def forward_extend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache: bool = True,
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run extend (prefill) attention for one layer.

    Dispatch:
      * target-verify / draft-extend batches -> parent implementation;
      * batches that attend to the prefix cache -> parent implementation;
      * everything else -> ``flashinfer.prefill.trtllm_ragged_attention_deepseek``
        driven by ``self.forward_prefill_metadata``.

    Args:
        q, k, v: Query / key / value activations. On the ragged path they are
            viewed as ``(-1, heads, head_dim)`` using the head counts and
            dims stored on ``layer``.
        layer: Attention layer providing ``tp_q_head_num``, ``tp_k_head_num``,
            ``head_dim``, ``v_head_dim`` and ``scaling``.
        forward_batch: Batch descriptor; its ``forward_mode``,
            ``attn_attend_prefix_cache``, ``batch_size`` and
            ``mha_return_lse`` are read here.
        save_kv_cache, q_rope, k_rope: Only forwarded to the parent
            implementation; the ragged path does not use them.

    Returns:
        Whatever the selected implementation returns.
        NOTE(review): with ``forward_batch.mha_return_lse`` set the kernel is
        asked for the LSE as well, so the result is presumably not a single
        tensor in that case -- confirm against the FlashInfer API.

    Raises:
        RuntimeError: If the ragged path is taken before
            ``init_forward_metadata`` stored prefill metadata.
    """
    mode = forward_batch.forward_mode
    if mode.is_target_verify() or mode.is_draft_extend():
        return super().forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
        )

    if forward_batch.attn_attend_prefix_cache:
        # TODO: replace with trtllm ragged attention once accuracy is resolved.
        return super().forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
        )

    prefill_metadata = self.forward_prefill_metadata
    if prefill_metadata is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError(
            "forward_prefill_metadata is not set; call init_forward_metadata() "
            "with an extend batch before forward_extend()."
        )

    q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
    k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
    v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)

    # Self-attention over the newly extended tokens only, so the query and
    # key/value sides share the same lengths and cumulative offsets.
    return flashinfer.prefill.trtllm_ragged_attention_deepseek(
        query=q,
        key=k,
        value=v,
        workspace_buffer=self.workspace_buffer,
        seq_lens=prefill_metadata.seq_lens,
        max_q_len=prefill_metadata.max_seq_len,
        max_kv_len=prefill_metadata.max_seq_len,
        bmm1_scale=layer.scaling,
        bmm2_scale=1.0,
        o_sf_scale=1.0,
        batch_size=forward_batch.batch_size,
        window_left=-1,
        cum_seq_lens_q=prefill_metadata.cum_seq_lens,
        cum_seq_lens_kv=prefill_metadata.cum_seq_lens,
        enable_pdl=False,
        is_causal=True,
        return_lse=forward_batch.mha_return_lse,
    )
|
||||||
|
|
||||||
|
|
||||||
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||||
"""Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
|
"""Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
|
||||||
|
|||||||
@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attention_backend == "flashinfer"
|
attention_backend == "flashinfer"
|
||||||
or attention_backend == "fa3"
|
or attention_backend == "fa3"
|
||||||
or attention_backend == "flashmla"
|
or attention_backend == "flashmla"
|
||||||
or attention_backend == "trtllm_mla"
|
|
||||||
or attention_backend == "cutlass_mla"
|
or attention_backend == "cutlass_mla"
|
||||||
):
|
):
|
||||||
# Use MHA with chunked KV cache when prefilling on long sequences.
|
# Use MHA with chunked KV cache when prefilling on long sequences.
|
||||||
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
|
elif attention_backend == "trtllm_mla":
|
||||||
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
|
else:
|
||||||
|
return _dispatch_mla_subtype()
|
||||||
elif attention_backend == "aiter":
|
elif attention_backend == "aiter":
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
|
|||||||
"v_head_dim": 512,
|
"v_head_dim": 512,
|
||||||
"num_kv_heads": 1,
|
"num_kv_heads": 1,
|
||||||
"layer_id": 0,
|
"layer_id": 0,
|
||||||
|
"tp_q_head_num": 128,
|
||||||
|
"tp_k_head_num": 128,
|
||||||
|
"prefill_head_dim": 192,
|
||||||
|
"prefill_v_head_dim": 128,
|
||||||
}
|
}
|
||||||
|
|
||||||
ROPE_BASE = 10000
|
ROPE_BASE = 10000
|
||||||
@@ -92,7 +96,7 @@ TEST_CASES = {
|
|||||||
"description": "Medium-scale batch",
|
"description": "Medium-scale batch",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"decode_output_match": [
|
"output_match": [
|
||||||
{
|
{
|
||||||
"name": "single_fp16",
|
"name": "single_fp16",
|
||||||
"batch_size": 1,
|
"batch_size": 1,
|
||||||
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
config.update(test_case)
|
config.update(test_case)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def _create_model_components(self, config):
|
def _create_model_components(self, config, is_prefill=False):
|
||||||
"""Create model runners, backends, and layer for testing."""
|
"""Create model runners, backends, and layer for testing."""
|
||||||
# Create model runners
|
# Create model runners
|
||||||
model_runner_trtllm = MockModelRunner(config)
|
model_runner_trtllm = MockModelRunner(config)
|
||||||
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
|
||||||
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
|
||||||
|
|
||||||
|
head_dim = (
|
||||||
|
config["kv_lora_rank"] + config["qk_rope_head_dim"]
|
||||||
|
if not is_prefill
|
||||||
|
else config["prefill_head_dim"]
|
||||||
|
)
|
||||||
|
v_head_dim = (
|
||||||
|
config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
|
||||||
|
)
|
||||||
|
|
||||||
# Create RadixAttention layer
|
# Create RadixAttention layer
|
||||||
layer = RadixAttention(
|
layer = RadixAttention(
|
||||||
num_heads=config["num_attention_heads"],
|
num_heads=config["num_attention_heads"],
|
||||||
head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
|
head_dim=head_dim,
|
||||||
scaling=model_runner_trtllm.model_config.scaling,
|
scaling=model_runner_trtllm.model_config.scaling,
|
||||||
num_kv_heads=config["num_kv_heads"],
|
num_kv_heads=config["num_kv_heads"],
|
||||||
layer_id=config["layer_id"],
|
layer_id=config["layer_id"],
|
||||||
v_head_dim=config["v_head_dim"],
|
v_head_dim=v_head_dim,
|
||||||
prefix="attn_mqa",
|
prefix="attn_mqa",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
|
||||||
print(f"\nRunning decode output matching tests...")
|
print(f"\nRunning decode output matching tests...")
|
||||||
|
|
||||||
for test_case in TEST_CASES["decode_output_match"]:
|
for test_case in TEST_CASES["output_match"]:
|
||||||
with self.subTest(test_case=test_case["name"]):
|
with self.subTest(test_case=test_case["name"]):
|
||||||
print(f" Testing {test_case['name']}: {test_case['description']}")
|
print(f" Testing {test_case['name']}: {test_case['description']}")
|
||||||
|
|
||||||
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
|
|||||||
self.assertIsNotNone(metadata_3.block_kv_indices)
|
self.assertIsNotNone(metadata_3.block_kv_indices)
|
||||||
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
|
||||||
|
|
||||||
|
def test_prefill_output_match_self_attention(self):
    """Check TRTLLM MLA prefill self-attention against the FlashInfer reference.

    For the first two ``TEST_CASES["output_match"]`` entries, builds an
    EXTEND batch with zero-length prefixes (so ``attn_attend_prefix_cache``
    is False), runs ``forward_extend`` on both backends with identical
    Q/K/V and asserts the outputs agree within the configured tolerance.
    """

    def _create_forward_batch_prefill(
        batch_size,
        seq_lens,
        extend_prefix_lens,
        backend,
        model_runner,
        config,
    ):
        """Create a forward batch for the given backend."""
        # NOTE(review): input_ids, out_cache_loc and positions are sized per
        # request rather than per token; they look unused on this code path
        # (save_kv_cache is False, no rope tensors are passed) -- confirm.
        fb = ForwardBatch(
            batch_size=batch_size,
            input_ids=torch.randint(
                0, 100, (batch_size, 1), device=config["device"]
            ),
            out_cache_loc=torch.arange(batch_size, device=config["device"]),
            seq_lens_sum=int(seq_lens.sum().item()),
            extend_prefix_lens=extend_prefix_lens,
            extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
            extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
            .cpu()
            .int()
            .tolist(),
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=torch.arange(batch_size, device=config["device"]),
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.cpu(),
            attn_attend_prefix_cache=False,
            mha_return_lse=False,
            attn_backend=backend,
        )
        fb.req_to_token_pool = model_runner.req_to_token_pool
        fb.token_to_kv_pool = model_runner.token_to_kv_pool

        # Add position information for RoPE
        fb.positions = torch.arange(batch_size, device=config["device"])

        return fb

    def _create_qkv_tensors_prefill(
        batch_size, seq_len, config, dtype_override=None
    ):
        """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
        device = config["device"]
        dtype = dtype_override or config["dtype"]

        total_tokens = batch_size * seq_len

        tp_q_head_num = config["tp_q_head_num"]
        tp_k_head_num = config["tp_k_head_num"]
        head_dim = config["prefill_head_dim"]
        v_head_dim = config["prefill_v_head_dim"]

        q = torch.randn(
            (total_tokens, tp_q_head_num * head_dim),
            dtype=dtype,
            device=device,
        )
        k = torch.randn(
            (total_tokens, tp_k_head_num * head_dim),
            dtype=dtype,
            device=device,
        )
        v = torch.randn(
            (total_tokens, tp_k_head_num * v_head_dim),
            dtype=dtype,
            device=device,
        )

        # Expose the head dimension explicitly: (tokens, heads, head_dim).
        q = q.view(-1, tp_q_head_num, head_dim)
        k = k.view(-1, tp_k_head_num, head_dim)
        v = v.view(-1, tp_k_head_num, v_head_dim)

        return q, k, v

    print("\nRunning prefill output tests...")

    for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
        with self.subTest(test_case=test_case["name"]):
            print(
                f"Prefill Testing {test_case['name']}: {test_case['description']}"
            )

            config = self._merge_config(test_case)
            batch_size = config["batch_size"]
            max_seq_len = config["max_seq_len"]

            # Create components
            (
                model_runner_trtllm,
                model_runner_reference,
                trtllm_backend,
                reference_backend,
                layer,
            ) = self._create_model_components(config, is_prefill=True)

            # Prefill uses full sequences
            seq_lens = torch.full(
                (batch_size,), max_seq_len, device=config["device"]
            )

            # Create forward batches (no cached prefix for either backend)
            fb_trtllm = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
                trtllm_backend,
                model_runner_trtllm,
                config,
            )
            fb_reference = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
                reference_backend,
                model_runner_reference,
                config,
            )

            # Initialize metadata for both backends
            trtllm_backend.init_forward_metadata(fb_trtllm)
            reference_backend.init_forward_metadata(fb_reference)

            # Create Q, K, V tensors for prefill; seed right before sampling
            # so the tensors are reproducible for a given config.
            torch.manual_seed(config["seed_qkv"])
            q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)

            # Run prefill on both backends (save_kv_cache=False)
            out_trtllm = trtllm_backend.forward_extend(
                q, k, v, layer, fb_trtllm, False
            ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
            out_reference = reference_backend.forward_extend(
                q, k, v, layer, fb_reference, False
            )

            tolerance = config.get("tolerance", 1e-2)
            comparison_passed = compare_outputs(
                out_trtllm, out_reference, tolerance=tolerance
            )
            self.assertTrue(
                comparison_passed,
                f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
                f"Config: {test_case['name']}, "
                f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
            )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user