[ST] Add e2e test for npugraph_ex pass (#6388)
### What this PR does / why we need it?
The custom NPUGraphEX passes already implement operator-fusion features, but until now they lacked E2E test cases to validate and guard them. This PR adds E2E test cases for the AddRMSNormQuant and SplitQKVNormRope operator fusions that are already in the codebase under NPUGraphEX.
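At a high level, each test compiles a small model with the `npugraph_ex` backend and then scans the optimized FX graph to check that the fused operator replaced the unfused ones. A minimal sketch of that check, mirroring the `find_op` helper added in the tests below (the exact assertion lists live in the test files):

```python
def find_op(gm, op):
    # True if the FX graph contains a call_function node targeting `op`.
    return any(node.op == "call_function" and node.target == op
               for node in gm.graph.nodes)

# After the AddRMSNormQuant fusion pass we expect only the fused kernel:
# assert find_op(gm, torch.ops.npu.npu_add_rms_norm_quant.default)
# assert not find_op(gm, torch.ops.npu.npu_add_rms_norm.default)
```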
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main: dc917cceb8
---------
Signed-off-by: cjian <2318164299@qq.com>
.github/workflows/scripts/config.yaml (vendored, 4 additions)
@@ -1,4 +1,8 @@
 e2e-singlecard:
+  - name: tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py
+    estimated_time: 80
+  - name: tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
+    estimated_time: 80
   - name: tests/e2e/singlecard/test_auto_fit_max_mode_len.py
     estimated_time: 25
   - name: tests/e2e/singlecard/test_aclgraph_accuracy.py
tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py (new file, 290 lines)
@@ -0,0 +1,290 @@
import copy

import pytest
import torch
import torch.nn as nn
import torch_npu
import torchair
import vllm.config
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.utils.system_utils import update_environment_variables

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.compilation.npugraph_ex_passes.graphex_norm_quant_fusion_pass import (
    GraphEXAddRMSNormQuantPattern,
    GraphEXAddRMSNormQuantPatternWithBias,
    GraphEXAddRMSNormQuantSPPattern,
    GraphEXAddRMSNormQuantSPPatternWithBias,
)


def find_op(gm, op_default):
    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
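

# The helper below swaps in a wrapper for torchair.npu_fx_compiler._optimize_fx,
# the private torchair entry point invoked during compilation to run the FX
# optimization passes, so each test can assert on the graph right after the
# fusion passes have been applied.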
def create_pattern_wrapper(assert_func):
    original_func = torchair.npu_fx_compiler._optimize_fx

    def wrapper(gm, example_inputs=None, config=None):
        ret = original_func(gm, example_inputs, config)
        # Run the assertion on a copy of the optimized graph so it cannot
        # mutate what the compiler hands back.
        graph_after = copy.deepcopy(gm)
        assert_func(graph_after)
        return ret

    return wrapper


class ModelWithoutBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → Quantization (without bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype = torch.bfloat16, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)

        quantized_output = torch.ops.vllm.quantize(
            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )

        return quantized_output, new_residual


class ModelWithBias(nn.Module):
    """
    A test model that simulates the pattern:
    AddRMSNorm → Add bias → Quantization (with bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype = torch.bfloat16, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Add bias
        3. Quantize to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)

        # Add bias
        norm_output_with_bias = norm_output + self.bias

        quantized_output = torch.ops.vllm.quantize(
            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )

        return quantized_output, new_residual


class ModelSPWithoutBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → maybe_allgather → Quantization (without bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype = torch.bfloat16, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Perform a fake maybe_all_gather_and_maybe_unpad
        3. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)

        norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output, True)

        quantized_output = torch.ops.vllm.quantize(
            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )

        return quantized_output, new_residual


class ModelSPWithBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    AddRMSNorm → Add bias → maybe_allgather → Quantization (with bias)
    """

    def __init__(self, hidden_size: int, dtype: torch.dtype = torch.bfloat16, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.bias = nn.Parameter(torch.randn(hidden_size, dtype=dtype, device=device))
        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Add bias
        3. Perform a fake maybe_all_gather_and_maybe_unpad
        4. Quantize the normalized output to int8
        Returns both quantized output and updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, self.rms_norm_weight, self.eps)

        # Add bias
        norm_output_with_bias = norm_output + self.bias

        norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output_with_bias, True)

        quantized_output = torch.ops.vllm.quantize(
            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
        )

        return quantized_output, new_residual


def assert_addrmsnorm_quant(after_gm, expect_fused=True, use_bias=False, sp_enable=False):
    check_rules = [
        (torch.ops.npu.npu_add_rms_norm_quant.default, expect_fused),
        (torch.ops.npu.npu_add_rms_norm.default, not expect_fused),
        (torch.ops.npu.npu_quantize.default, not expect_fused),
    ]
    if use_bias:
        check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
    if sp_enable:
        check_rules.append((torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, expect_fused))
    for torch_op, expect_exist in check_rules:
        found = find_op(after_gm, torch_op)
        if expect_exist:
            assert found, f"Expected operator '{torch_op}' but it was not found"
        else:
            assert not found, f"Did not expect operator '{torch_op}' but it was found"
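

# Fusion patterns register into shared pattern-matcher state and all
# parametrized cases run in one test process, so registering the same pattern
# twice raises a RuntimeError mentioning "Duplicate pattern"; track what has
# been registered and skip repeats.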
_registered_patterns = set()


def register_pattern_safe(pattern_class, vllm_config, eps, pattern_key):
    global _registered_patterns
    if pattern_key in _registered_patterns:
        print(f"Pattern {pattern_key} already registered, skipping...")
        return None

    pattern = pattern_class(vllm_config=vllm_config, eps=eps)
    try:
        pattern.register()
        _registered_patterns.add(pattern_key)
        print(f"Successfully registered pattern: {pattern_key}")
    except RuntimeError as e:
        if "Duplicate pattern" in str(e):
            print(f"Pattern {pattern_key} already exists (caught from RuntimeError), skipping...")
            _registered_patterns.add(pattern_key)
        else:
            raise
    return pattern


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5])
@pytest.mark.parametrize("use_bias", [False, True])
@pytest.mark.parametrize("sp_enable", [False, True])
def test_rmsnorm_quant_fusion(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    use_bias: bool,
    sp_enable: bool,
):
    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
    with vllm.config.set_current_vllm_config(vllm_config):
        update_environment_variables(
            {
                "RANK": "0",
                "LOCAL_RANK": "0",
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "12345",
            }
        )
        init_distributed_environment()
        ensure_model_parallel_initialized(1, 1)

    with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
        if use_bias:
            if sp_enable:
                model = ModelSPWithBias(hidden_size, dtype, eps, device="npu")
                register_pattern_safe(
                    GraphEXAddRMSNormQuantSPPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantSPPatternWithBias"
                )
            else:
                model = ModelWithBias(hidden_size, dtype, eps, device="npu")
                register_pattern_safe(
                    GraphEXAddRMSNormQuantPatternWithBias, vllm_config, eps, "GraphEXAddRMSNormQuantPatternWithBias"
                )
        else:
            if sp_enable:
                model = ModelSPWithoutBias(hidden_size, dtype, eps, device="npu")
                register_pattern_safe(
                    GraphEXAddRMSNormQuantSPPattern, vllm_config, eps, "GraphEXAddRMSNormQuantSPPattern"
                )
            else:
                model = ModelWithoutBias(hidden_size, dtype, eps, device="npu")
                register_pattern_safe(GraphEXAddRMSNormQuantPattern, vllm_config, eps, "GraphEXAddRMSNormQuantPattern")

        model = model.to("npu")
        x = torch.randn(num_tokens, hidden_size, device="npu", dtype=dtype, requires_grad=False)

        with torch.no_grad():
            original_optimize = torchair.npu_fx_compiler._optimize_fx
            torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
                lambda gm: assert_addrmsnorm_quant(gm, expect_fused=True, use_bias=use_bias, sp_enable=sp_enable)
            )

            compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)

            compiled_out, compiled_res = compiled_model(x)

            torchair.npu_fx_compiler._optimize_fx = original_optimize
tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py (new file, 222 lines)
@@ -0,0 +1,222 @@
import copy

import pytest
import torch
import torch.nn as nn
import torchair
import vllm.config
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.utils.system_utils import update_environment_variables

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.compilation.npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import (
    GraphEXQKNormRopeFusionPattern,
    GraphEXQKNormRopeFusionPatternWithBias,
)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton


def find_op(gm, op_default):
    return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)


def create_pattern_wrapper(assert_func):
    original_func = torchair.npu_fx_compiler._optimize_fx

    def wrapper(gm, example_inputs=None, config=None):
        ret = original_func(gm, example_inputs, config)
        graph_after = copy.deepcopy(gm)
        assert_func(graph_after)
        return ret

    return wrapper
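

# Module-scoped, autouse fixture: initializes Triton device properties once for
# these tests. (Presumably needed because the fused qkv_rmsnorm_rope path is
# Triton-backed; the helper comes from vllm_ascend.ops.triton.triton_utils.)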
@pytest.fixture(scope="module", autouse=True)
def init_triton():
    init_device_properties_triton()


class ModelQKNormRopeWithoutBias(nn.Module):
    """
    A minimal test model that simulates the pattern:
    Split QKV → per-head RMSNorm → RoPE (without bias)
    """

    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dtype: torch.dtype = torch.bfloat16,
        eps: float = 1e-6,
        device="npu",
    ):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.eps = eps

        # RMSNorm weight per head (shared across heads of same type)
        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

    def forward(self, qkv, cos, sin):
        """
        Args:
            qkv: [T, q_size + 2*kv_size]
            cos: [1, T, 1, head_dim]
            sin: [1, T, 1, head_dim]
        Returns:
            q_rope, k_rope, v
        """
        # Split QKV
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Q RMSNorm (per-head)
        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)

        # K RMSNorm (per-head)
        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)

        # Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
        q_flat = q_norm_out.view(q.shape)
        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)

        k_flat = k_norm_out.view(k.shape)
        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)

        # Apply RoPE
        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)

        return q_rope, k_rope, v


class ModelQKNormRopeWithBias(nn.Module):
    """
    A test model that simulates the pattern:
    Split QKV → per-head RMSNorm → Add bias → RoPE (with bias)
    """

    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dtype: torch.dtype = torch.bfloat16,
        eps: float = 1e-6,
        device="npu",
    ):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.eps = eps

        self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
        self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

    def forward(self, qkv, cos, sin):
        # Split QKV
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Q RMSNorm + Bias
        q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, self.q_weight, self.eps)
        q_normed = q_norm_out + self.q_bias

        # K RMSNorm + Bias
        k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, self.k_weight, self.eps)
        k_normed = k_norm_out + self.k_bias

        # Reshape for RoPE
        q_flat = q_normed.view(q.shape)
        q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)

        k_flat = k_normed.view(k.shape)
        k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)

        # Apply RoPE
        q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)

        return q_rope, k_rope, v


def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
    check_rules = [
        (torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
        (torch.ops.npu.npu_rms_norm.default, not expect_fused),
        (torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
    ]
    if use_bias:
        check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
    for torch_op, expect_exist in check_rules:
        found = find_op(after_gm, torch_op)
        if expect_exist:
            assert found, f"Expected operator '{torch_op}' but it was not found"
        else:
            assert not found, f"Did not expect operator '{torch_op}' but it was found"


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5])
@pytest.mark.parametrize("use_bias", [False, True])
def test_qknorm_rope_fusion(
    dtype: torch.dtype,
    hidden_size: int,
    num_tokens: int,
    eps: float,
    use_bias: bool,
):
    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
    with vllm.config.set_current_vllm_config(vllm_config):
        update_environment_variables(
            {
                "RANK": "0",
                "LOCAL_RANK": "0",
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "12345",
            }
        )
        init_distributed_environment()
        ensure_model_parallel_initialized(1, 1)
    num_heads = 16
    num_kv_heads = 8
    head_dim = 128
    with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
        fusion_pattern = None
        q_size = num_heads * head_dim
        kv_size = num_kv_heads * head_dim
        qkv_size = q_size + 2 * kv_size
        if use_bias:
            model = ModelQKNormRopeWithBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
            fusion_pattern = GraphEXQKNormRopeFusionPatternWithBias(
                vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
            )
        else:
            model = ModelQKNormRopeWithoutBias(head_dim, num_heads, num_kv_heads, dtype, eps, device="npu")
            fusion_pattern = GraphEXQKNormRopeFusionPattern(
                vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
            )
        fusion_pattern.register()
        model = model.to("npu")
        seq_len = 5
        qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
        cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
        sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)

        with torch.no_grad():
            original_optimize = torchair.npu_fx_compiler._optimize_fx
            torchair.npu_fx_compiler._optimize_fx = create_pattern_wrapper(
                lambda gm: assert_qknorm_rope_fusion(gm, expect_fused=True, use_bias=use_bias)
            )

            compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)

            compiled_model(qkv, cos, sin)

            torchair.npu_fx_compiler._optimize_fx = original_optimize