[ST] Add e2e test for Npugraphex_pass (#6388)

### What this PR does / why we need it?
The custom passes in NPUGraphEX already implement fused-operator features, but
they still lack E2E test coverage to validate and guard them. This PR adds E2E
test cases for the AddRMSNormQuant and SplitQKVNormRope operator fusions that
already exist in the codebase under NPUGraphEX.
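
Both new tests use the same verification approach: compile a small reference module with the `npugraph_ex` backend, wrap torchair's `_optimize_fx` hook to capture the optimized FX graph, and assert that the fused operator has replaced the original ops. A minimal sketch of that graph check (illustrative only, not the exact test code; the op names are taken from the tests in this PR):

```python
import torch
import torch_npu  # registers the torch.ops.npu operator namespace


def graph_contains(gm, target_op):
    """True if the FX graph contains a call to `target_op`."""
    return any(
        node.op == "call_function" and node.target == target_op
        for node in gm.graph.nodes
    )


def assert_fusion_fired(gm):
    # After the AddRMSNormQuant pass fires, the fused kernel must be present
    # and its unfused building blocks must be gone.
    assert graph_contains(gm, torch.ops.npu.npu_add_rms_norm_quant.default)
    assert not graph_contains(gm, torch.ops.npu.npu_add_rms_norm.default)
    assert not graph_contains(gm, torch.ops.npu.npu_quantize.default)
```
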
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: dc917cceb8

---------

Signed-off-by: cjian <2318164299@qq.com>
Author: CodeCat
Date: 2026-01-30 09:14:07 +08:00 (committed by GitHub)
Parent: 4970de4242
Commit: b2857de43f
3 changed files with 516 additions and 0 deletions


@@ -1,4 +1,8 @@
e2e-singlecard:
  - name: tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py
    estimated_time: 80
  - name: tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
    estimated_time: 80
  - name: tests/e2e/singlecard/test_auto_fit_max_model_len.py
    estimated_time: 25
  - name: tests/e2e/singlecard/test_aclgraph_accuracy.py
tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py

@@ -0,0 +1,290 @@
import copy
import pytest
import torch
import torch.nn as nn
import torch_npu
import torchair
import vllm.config
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.utils.system_utils import update_environment_variables
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.compilation.npugraph_ex_passes.graphex_norm_quant_fusion_pass import (
    GraphEXAddRMSNormQuantPattern,
    GraphEXAddRMSNormQuantPatternWithBias,
    GraphEXAddRMSNormQuantSPPattern,
    GraphEXAddRMSNormQuantSPPatternWithBias,
)


def find_op(gm, op_default):
    # True if the FX graph contains a call_function node targeting the given op.
    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)


def create_pattern_wrapper(assert_func):
    # Wrap torchair's private _optimize_fx hook so the optimized FX graph can be
    # inspected and asserted on right after the fusion passes have run.
    original_func = torchair.npu_fx_compiler._optimize_fx

    def wrapper(gm, example_inputs=None, config=None):
        ret = original_func(gm, example_inputs, config)
        graph_after = copy.deepcopy(gm)
        assert_func(graph_after)
        return ret

    return wrapper


class ModelWithoutBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → Quantization (without bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)
        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)
        quantized_output = torch.ops.vllm.quantize(
            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )
        return quantized_output, new_residual


class ModelWithBias(nn.Module):
    """
    A test model that simulates the pattern:
    AddRMSNorm → Add Bias → Quantization (with bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Add bias
        3. Quantize to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)
        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)
        # Add bias
        norm_output_with_bias = norm_output + self.bias
        quantized_output = torch.ops.vllm.quantize(
            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )
        return quantized_output, new_residual


class ModelSPWithoutBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → maybe_allgather → Quantization (without bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Perform a fake maybe_all_gather_and_maybe_unpad
        3. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)
        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)
        norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output, True)
        quantized_output = torch.ops.vllm.quantize(
            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )
        return quantized_output, new_residual


class ModelSPWithBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → Add bias → maybe_allgather → Quantization (with bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Add bias
        3. Perform a fake maybe_all_gather_and_maybe_unpad
        4. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)
        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)
        # Add bias
        norm_output_with_bias = norm_output + self.bias
        norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output_with_bias, True)
        quantized_output = torch.ops.vllm.quantize(
            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )
        return quantized_output, new_residual


def assert_addrmsnorm_quant(after_gm, expect_fused=True, use_bias=False, sp_enable=False):
    check_rules = [
        (torch.ops.npu.npu_add_rms_norm_quant.default, expect_fused),
        (torch.ops.npu.npu_add_rms_norm.default, not expect_fused),
        (torch.ops.npu.npu_quantize.default, not expect_fused),
    ]
    if use_bias:
        check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
    if sp_enable:
        check_rules.append((torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, expect_fused))
    for torch_op, expect_exist in check_rules:
        found = find_op(after_gm, torch_op)
        if expect_exist:
            assert found, f"Expected operator '{torch_op}' in the graph, but it was not found"
        else:
            assert not found, f"Did not expect operator '{torch_op}' in the graph, but it was found"
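

# Pattern registration is process-global and raises on duplicate registration,
# while pytest re-runs this test for every parametrization; track what has
# already been registered and skip repeats.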
_registered_patterns = set()


def register_pattern_safe(pattern_class, vllm_config, eps, pattern_key):
    global _registered_patterns
    if pattern_key in _registered_patterns:
        print(f"Pattern {pattern_key} already registered, skipping...")
        return None
    pattern = pattern_class(vllm_config=vllm_config, eps=eps)
    try:
        pattern.register()
        _registered_patterns.add(pattern_key)
        print(f"Successfully registered pattern: {pattern_key}")
    except RuntimeError as e:
        if "Duplicate pattern" in str(e):
            print(f"Pattern {pattern_key} already exists (caught from RuntimeError), skipping...")
            _registered_patterns.add(pattern_key)
        else:
            raise e
    return pattern
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5])
@pytest.mark.parametrize("use_bias", [False, True])
@pytest.mark.parametrize("sp_enable", [False, True])
def test_rmsnorm_quant_fusion(
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
eps: float,
use_bias: bool,
sp_enable: bool,
):
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
update_environment_variables(
{
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
init_distributed_environment()
ensure_model_parallel_initialized(1, 1)
with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
if use_bias:
if sp_enable:
model = ModelSPWithBias(hidden_size, dtype, eps, device="npu")
register_pattern_safe(
GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantSPPatternWithBias"
)
else:
model = ModelWithBias(hidden_size, dtype, eps, device="npu")
register_pattern_safe(
GraphEXAddRMSNormQuantPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantPatternWithBias"
)
else:
if sp_enable:
model = ModelSPWithoutBias(hidden_size, dtype, eps, device="npu")
register_pattern_safe(
GraphEXAddRMSNormQuantSPPattern, vllm_config, eps, "GraphEXAddRMSNormQuantSPPattern"
)
else:
model = ModelWithoutBias(hidden_size, dtype, eps, device="npu")
register_pattern_safe(GraphEXAddRMSNormQuantPattern, vllm_config, eps, "GraphEXAddRMSNormQuantPattern")
model = model.to("npu")
x = torch.randn(num_tokens, hidden_size, device="npu", dtype=dtype, requires_grad=False)
with torch.no_grad():
original_optimize = torchair.npu_fx_compiler._optimize_fx
torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
lambda gm: assert_addrmsnorm_quant(gm, expect_fused=True, use_bias=use_bias, sp_enable=sp_enable)
)
compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
compiled_out, compiled_res = compiled_model(x)
torchair.npu_fx_compiler._optimize_fx = original_optimize

tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py

@@ -0,0 +1,222 @@
import copy
import pytest
import torch
import torch.nn as nn
import torchair
import vllm.config
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.utils.system_utils import update_environment_variables
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.compilation.npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import (
    GraphEXQKNormRopeFusionPattern,
    GraphEXQKNormRopeFusionPatternWithBias,
)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton


def find_op(gm, op_default):
    # True if the FX graph contains a call_function node targeting the given op.
    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)


def create_pattern_wrapper(assert_func):
    # Wrap torchair's private _optimize_fx hook so the optimized FX graph can be
    # inspected and asserted on right after the fusion passes have run.
    original_func = torchair.npu_fx_compiler._optimize_fx

    def wrapper(gm, example_inputs=None, config=None):
        ret = original_func(gm, example_inputs, config)
        graph_after = copy.deepcopy(gm)
        assert_func(graph_after)
        return ret

    return wrapper
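

# The fused qkv_rmsnorm_rope kernel appears to be Triton-based (an assumption
# based on the triton_utils import above), so Triton device properties are
# initialized once for the whole test module.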
@pytest.fixture(scope="module", autouse=True)
def init_triton():
    init_device_properties_triton()


class ModelQKNormRopeWithoutBias(nn.Module):
    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dtype: torch.dtype = torch.bfloat16,
        eps: float = 1e-6,
        device="npu",
    ):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.eps = eps
        # RMSNorm weight per head (shared across heads of same type)
        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

    def forward(self, qkv, cos, sin):
        """
        Args:
            qkv: [T, q_size + 2*kv_size]
            cos: [1, T, 1, head_dim]
            sin: [1, T, 1, head_dim]
        Returns:
            q_rope, k_rope, v
        """
        # Split QKV
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Q RMSNorm (per-head)
        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)
        # K RMSNorm (per-head)
        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)
        # Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
        q_flat = q_norm_out.view(q.shape)
        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
        k_flat = k_norm_out.view(k.shape)
        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
        # Apply RoPE
        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
        return q_rope, k_rope, v


class ModelQKNormRopeWithBias(nn.Module):
    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dtype: torch.dtype = torch.bfloat16,
        eps: float = 1e-6,
        device="npu",
    ):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.eps = eps
        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

    def forward(self, qkv, cos, sin):
        # Split QKV
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Q RMSNorm + Bias
        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)
        q_normed = q_norm_out + self.q_bias
        # K RMSNorm + Bias
        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)
        k_normed = k_norm_out + self.k_bias
        # Reshape for RoPE
        q_flat = q_normed.view(q.shape)
        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
        k_flat = k_normed.view(k.shape)
        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
        # Apply RoPE
        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
        return q_rope, k_rope, v


def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
    check_rules = [
        (torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
        (torch.ops.npu.npu_rms_norm.default, not expect_fused),
        (torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
    ]
    if use_bias:
        check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
    for torch_op, expect_exist in check_rules:
        found = find_op(after_gm, torch_op)
        if expect_exist:
            assert found, f"Expected operator '{torch_op}' in the graph, but it was not found"
        else:
            assert not found, f"Did not expect operator '{torch_op}' in the graph, but it was found"
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5])
@pytest.mark.parametrize("use_bias", [False, True])
def test_rmsnorm_quant_fusion(
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
eps: float,
use_bias: bool,
):
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
update_environment_variables(
{
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
init_distributed_environment()
ensure_model_parallel_initialized(1, 1)
num_heads = 16
num_kv_heads = 8
head_dim = 128
with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
fusion_pattern = None
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
qkv_size = q_size + 2 * kv_size
if use_bias:
model = ModelQKNormRopeWithBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
fusion_pattern = GraphEXQKNormRopeFusionPatternWithBias(
vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
)
else:
model = ModelQKNormRopeWithoutBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
fusion_pattern = GraphEXQKNormRopeFusionPattern(
vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
)
fusion_pattern.register()
model = model.to("npu")
seq_len = 5
qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
with torch.no_grad():
original_optimize = torchair.npu_fx_compiler._optimize_fx
torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
lambda gm: assert_qknorm_rope_fusion(gm, expect_fused=True, use_bias=use_bias)
)
compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
compiled_model(qkv, cos, sin)
torchair.npu_fx_compiler._optimize_fx = original_optimize