diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml
index 0509c4cf..61327fff 100644
--- a/.github/workflows/scripts/config.yaml
+++ b/.github/workflows/scripts/config.yaml
@@ -1,4 +1,8 @@
 e2e-singlecard:
+  - name: tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py
+    estimated_time: 80
+  - name: tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
+    estimated_time: 80
   - name: tests/e2e/singlecard/test_auto_fit_max_mode_len.py
     estimated_time: 25
   - name: tests/e2e/singlecard/test_aclgraph_accuracy.py
diff --git a/tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py b/tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py
new file mode 100644
index 00000000..3e9514e1
--- /dev/null
+++ b/tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py
@@ -0,0 +1,290 @@
+"""E2E test: verify that the GraphEX AddRMSNorm + quantize fusion pass rewrites the FX graph."""
+
+import copy
+
+import pytest
+import torch
+import torch.nn as nn
+import torch_npu
+import torchair
+import vllm.config
+from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
+from vllm.utils.system_utils import update_environment_variables
+
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.compilation.npugraph_ex_passes.graphex_norm_quant_fusion_pass import (
+    GraphEXAddRMSNormQuantPattern,
+    GraphEXAddRMSNormQuantPatternWithBias,
+    GraphEXAddRMSNormQuantSPPattern,
+    GraphEXAddRMSNormQuantSPPatternWithBias,
+)
+
+
+def find_op(gm, op_default):
+    """Return True if the graph contains a call_function node targeting op_default."""
+    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
+
+
+def create_pattern_wrapper(assert_func):
+    """Wrap torchair's FX optimizer so assert_func can inspect the optimized graph."""
+    original_func = torchair.npu_fx_compiler._optimize_fx
+
+    def wrapper(gm, example_inputs=None, config=None):
+        ret = original_func(gm, example_inputs, config)
+        graph_after = copy.deepcopy(gm)
+        assert_func(graph_after)
+        return ret
+
+    return wrapper
+
+
+class ModelWithoutBias(nn.Module):
+    """
+    A minimal test model that simulates the pattern:
+    AddRMSNorm → Quantization (without bias)
+    """
+
+    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
+        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Quantize the normalized output to int8
+        Returns both quantized output and updated residual.
+ """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps) + + quantized_output = torch.ops.vllm.quantize( + norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset + ) + + return quantized_output, new_residual + + +class ModelWithBias(nn.Module): + """ + A test model that simulates the pattern: + AddRMSNorm → Add Bias → Quantization (with bias) + """ + + def __init__(self, hidden_size: int, dtype: torch.bfloat16, eps: float = 1e-6, device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device)) + self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device)) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Add bias + 3. Quantize to int8 + Returns both quantized output and updated residual. + """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps) + + # Add bias + norm_output_with_bias = norm_output + self.bias + + quantized_output = torch.ops.vllm.quantize( + norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset + ) + + return quantized_output, new_residual + + +class ModelSPWithoutBias(nn.Module): + """ + A minimal test model that simulates the pattern: + AddRMSNorm → maybe_allgather → Quantization (without bias) + """ + + def __init__(self, hidden_size: int, dtype: torch.bfloat16, eps: float = 1e-6, device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device)) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Perform a fake maybe_all_gather_and_maybe_unpad + 3. Quantize the normalized output to int8 + Returns both quantized output and updated residual. 
+ """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps) + + norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output, True) + + quantized_output = torch.ops.vllm.quantize( + norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset + ) + + return quantized_output, new_residual + + +class ModelSPWithBias(nn.Module): + """ + A minimal test model that simulates the pattern: + AddRMSNorm → Add bias → maybe_allgather → Quantization (without bias) + """ + + def __init__(self, hidden_size: int, dtype: torch.bfloat16, eps: float = 1e-6, device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device)) + self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device)) + self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device) + self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Add bias + 3. Perform a fake maybe_all_gather_and_maybe_unpad + 4. Quantize the normalized output to int8 + Returns both quantized output and updated residual. + """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps) + + # Add bias + norm_output_with_bias = norm_output + self.bias + + norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output_with_bias, True) + + quantized_output = torch.ops.vllm.quantize( + norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset + ) + + return quantized_output, new_residual + + +def assert_addrmsnorm_quant(after_gm, expect_fused=True, use_bias=False, sp_enable=False): + check_rules = [ + (torch.ops.npu.npu_add_rms_norm_quant.default, expect_fused), + (torch.ops.npu.npu_add_rms_norm.default, not expect_fused), + (torch.ops.npu.npu_quantize.default, not expect_fused), + ] + if use_bias: + check_rules.append((torch.ops.aten.add.Tensor, not expect_fused)) + if sp_enable: + check_rules.append((torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, expect_fused)) + for torch_op, expect_exist in check_rules: + found = find_op(after_gm, torch_op) + if expect_exist: + assert found, f"Expected operator '{torch_op}' but not find" + else: + assert not found, f"Not expected operator '{torch_op}' but find" + + +_registered_patterns = set() + + +def register_pattern_safe(pattern_class, vllm_config, eps, pattern_key): + global _registered_patterns + if pattern_key in _registered_patterns: + print(f"Pattern {pattern_key} already registered, skipping...") + return None + + pattern = pattern_class(vllm_config=vllm_config, eps=eps) + try: + pattern.register() + _registered_patterns.add(pattern_key) + print(f"Successfully registered pattern: {pattern_key}") + except RuntimeError as e: + if "Duplicate pattern" in str(e): + print(f"Pattern {pattern_key} already exists (caught from RuntimeError), skipping...") + _registered_patterns.add(pattern_key) + else: + raise e + return pattern + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [257]) +@pytest.mark.parametrize("eps", [1e-5]) +@pytest.mark.parametrize("use_bias", [False, 
+@pytest.mark.parametrize("sp_enable", [False, True])
+def test_rmsnorm_quant_fusion(
+    dtype: torch.dtype,
+    hidden_size: int,
+    num_tokens: int,
+    eps: float,
+    use_bias: bool,
+    sp_enable: bool,
+):
+    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        update_environment_variables(
+            {
+                "RANK": "0",
+                "LOCAL_RANK": "0",
+                "WORLD_SIZE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": "12345",
+            }
+        )
+        init_distributed_environment()
+        ensure_model_parallel_initialized(1, 1)
+
+    with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
+        if use_bias:
+            if sp_enable:
+                model = ModelSPWithBias(hidden_size, dtype, eps, device="npu")
+                register_pattern_safe(
+                    GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantSPPatternWithBias"
+                )
+            else:
+                model = ModelWithBias(hidden_size, dtype, eps, device="npu")
+                register_pattern_safe(
+                    GraphEXAddRMSNormQuantPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantPatternWithBias"
+                )
+        else:
+            if sp_enable:
+                model = ModelSPWithoutBias(hidden_size, dtype, eps, device="npu")
+                register_pattern_safe(
+                    GraphEXAddRMSNormQuantSPPattern, vllm_config, eps, "GraphEXAddRMSNormQuantSPPattern"
+                )
+            else:
+                model = ModelWithoutBias(hidden_size, dtype, eps, device="npu")
+                register_pattern_safe(GraphEXAddRMSNormQuantPattern, vllm_config, eps, "GraphEXAddRMSNormQuantPattern")
+
+        model = model.to("npu")
+        x = torch.randn(num_tokens, hidden_size, device="npu", dtype=dtype, requires_grad=False)
+
+        with torch.no_grad():
+            # Patch torchair's FX optimizer so the post-pass graph can be
+            # inspected, and restore it even if compilation fails.
+            original_optimize = torchair.npu_fx_compiler._optimize_fx
+            torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
+                lambda gm: assert_addrmsnorm_quant(gm, expect_fused=True, use_bias=use_bias, sp_enable=sp_enable)
+            )
+            try:
+                compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
+                compiled_model(x)
+            finally:
+                torchair.npu_fx_compiler._optimize_fx = original_optimize
diff --git a/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
new file mode 100644
index 00000000..f696ffbc
--- /dev/null
+++ b/tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
@@ -0,0 +1,222 @@
+"""E2E test: verify that the GraphEX QK-norm + RoPE fusion pass rewrites the FX graph."""
+
+import copy
+
+import pytest
+import torch
+import torch.nn as nn
+import torch_npu  # noqa: F401  # registers the torch.ops.npu custom operators
+import torchair
+import vllm.config
+from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
+from vllm.utils.system_utils import update_environment_variables
+
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.compilation.npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import (
+    GraphEXQKNormRopeFusionPattern,
+    GraphEXQKNormRopeFusionPatternWithBias,
+)
+from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+
+
+def find_op(gm, op_default):
+    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
+
+
+def create_pattern_wrapper(assert_func):
+    original_func = torchair.npu_fx_compiler._optimize_fx
+
+    def wrapper(gm, example_inputs=None, config=None):
+        ret = original_func(gm, example_inputs, config)
+        graph_after = copy.deepcopy(gm)
+        assert_func(graph_after)
+        return ret
+
+    return wrapper
+
+
+@pytest.fixture(scope="module", autouse=True)
+def init_triton():
+    init_device_properties_triton()
+
+
+class ModelQKNormRopeWithoutBias(nn.Module):
+    def __init__(
+        self,
+        head_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        dtype: torch.dtype = torch.bfloat16,
+        eps: float = 1e-6,
+        device="npu",
+    ):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.q_size = num_heads * head_dim
+        self.kv_size = num_kv_heads * head_dim
+        self.eps = eps
+
+        # RMSNorm weight per head (shared across heads of same type)
+        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+
+    def forward(self, qkv, cos, sin):
+        """
+        Args:
+            qkv: [T, q_size + 2*kv_size]
+            cos: [1, T, 1, head_dim]
+            sin: [1, T, 1, head_dim]
+        Returns:
+            q_rope, k_rope, v
+        """
+        # Split QKV
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Q RMSNorm (per-head)
+        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
+        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)
+
+        # K RMSNorm (per-head)
+        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
+        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)
+
+        # Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
+        q_flat = q_norm_out.view(q.shape)
+        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
+
+        k_flat = k_norm_out.view(k.shape)
+        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
+
+        # Apply RoPE
+        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
+
+        return q_rope, k_rope, v
+
+
+class ModelQKNormRopeWithBias(nn.Module):
+    def __init__(
+        self,
+        head_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        dtype: torch.dtype = torch.bfloat16,
+        eps: float = 1e-6,
+        device="npu",
+    ):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.q_size = num_heads * head_dim
+        self.kv_size = num_kv_heads * head_dim
+        self.eps = eps
+
+        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+        self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+        self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
+
+    def forward(self, qkv, cos, sin):
+        # Split QKV
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Q RMSNorm + Bias
+        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
+        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)
+        q_normed = q_norm_out + self.q_bias
+
+        # K RMSNorm + Bias
+        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
+        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)
+        k_normed = k_norm_out + self.k_bias
+
+        # Reshape for RoPE
+        q_flat = q_normed.view(q.shape)
+        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
+
+        k_flat = k_normed.view(k.shape)
+        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
+
+        # Apply RoPE
+        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
+
+        return q_rope, k_rope, v
+
+
+def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
+    check_rules = [
+        (torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
+        (torch.ops.npu.npu_rms_norm.default, not expect_fused),
+        (torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
+    ]
+    if use_bias:
+        check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
+    for torch_op, expect_exist in check_rules:
+        found = find_op(after_gm, torch_op)
+        if expect_exist:
+            assert found, f"Expected operator '{torch_op}' but it was not found"
+        else:
+            assert not found, f"Did not expect operator '{torch_op}' but it was found"
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("eps", [1e-5])
+@pytest.mark.parametrize("use_bias", [False, True])
+def test_qknorm_rope_fusion(
+    dtype: torch.dtype,
+    eps: float,
+    use_bias: bool,
+):
+    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        update_environment_variables(
+            {
+                "RANK": "0",
+                "LOCAL_RANK": "0",
+                "WORLD_SIZE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": "12345",
+            }
+        )
+        init_distributed_environment()
+        ensure_model_parallel_initialized(1, 1)
+    num_heads = 16
+    num_kv_heads = 8
+    head_dim = 128
+    with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
+        q_size = num_heads * head_dim
+        kv_size = num_kv_heads * head_dim
+        qkv_size = q_size + 2 * kv_size
+        if use_bias:
+            model = ModelQKNormRopeWithBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
+            fusion_pattern = GraphEXQKNormRopeFusionPatternWithBias(
+                vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
+            )
+        else:
+            model = ModelQKNormRopeWithoutBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
+            fusion_pattern = GraphEXQKNormRopeFusionPattern(
+                vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
+            )
+        fusion_pattern.register()
+        model = model.to("npu")
+        seq_len = 5
+        qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
+        cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
+        sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
+
+        with torch.no_grad():
+            # Patch torchair's FX optimizer so the post-pass graph can be
+            # inspected, and restore it even if compilation fails.
+            original_optimize = torchair.npu_fx_compiler._optimize_fx
+            torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
+                lambda gm: assert_qknorm_rope_fusion(gm, expect_fused=True, use_bias=use_bias)
+            )
+            try:
+                compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
+                compiled_model(qkv, cos, sin)
+            finally:
+                torchair.npu_fx_compiler._optimize_fx = original_optimize