Sync from v0.13
This commit is contained in:
196
tests/compile/test_qk_norm_rope_fusion.py
Normal file
196
tests/compile/test_qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.qk_norm_rope_fusion import (
|
||||
FUSED_QK_ROPE_OP,
|
||||
QKNormRoPEFusionPass,
|
||||
)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RSQRT_OP = torch.ops.aten.rsqrt.default
|
||||
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
|
||||
|
||||
|
||||
class QKNormRoPETestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
eps: float,
|
||||
is_neox: bool,
|
||||
vllm_config: VllmConfig,
|
||||
dtype: torch.dtype,
|
||||
prefix: str = "model.layers.0.self_attn.attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.q_size = num_heads * head_dim
|
||||
self.kv_size = num_kv_heads * head_dim
|
||||
self.rotary_dim = head_dim
|
||||
self.eps = eps
|
||||
self.dtype = dtype
|
||||
|
||||
# Register layer metadata for the fusion pass via Attention.
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=1.0 / self.head_dim**0.5,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=vllm_config.cache_config,
|
||||
prefix=prefix,
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000,
|
||||
is_neox_style=is_neox,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
||||
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
||||
|
||||
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
return q, k, v
|
||||
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rms_norm_custom_op:
|
||||
ops.append(RMS_OP)
|
||||
else:
|
||||
ops.append(RSQRT_OP)
|
||||
|
||||
if self.enable_rope_custom_op:
|
||||
if self.rotary_emb.use_flashinfer:
|
||||
ops.append(FLASHINFER_ROTARY_OP)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
||||
return [FUSED_QK_ROPE_OP]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("is_neox", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_qk_norm_rope_fusion(
|
||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||
):
|
||||
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
||||
pytest.skip("fused_qk_norm_rope custom op not available")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
custom_ops: list[str] = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_rope_custom_op:
|
||||
custom_ops.append("+rotary_embedding")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
T = 5
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = QKNormRoPETestModel(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
eps=eps,
|
||||
is_neox=is_neox,
|
||||
vllm_config=vllm_config,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
||||
|
||||
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
||||
pos = torch.arange(T, dtype=torch.long, device=qkv.device)
|
||||
qkv_unfused = qkv.clone()
|
||||
pos_unfused = pos.clone()
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
torch._dynamo.mark_dynamic(pos, 0)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
q_fused, k_fused, v_fused = model_fused(qkv, pos)
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv_unfused, 0)
|
||||
torch._dynamo.mark_dynamic(pos_unfused, 0)
|
||||
model_unfused = torch.compile(model, backend=backend_baseline)
|
||||
q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
Reference in New Issue
Block a user