[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
def register_ascend_customop(vllm_config: VllmConfig | None = None):
^^^^^^^^^^^^^^^^^
E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None # Placeholder assigned at runtime
```
When Python parses the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts
to evaluate the expression `VllmConfig | None`.
Since `VllmConfig` is assigned `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`. While the `|` operator is implemented for Type objects
(classes), it is not supported for `NoneType` instances, leading to the
`TypeError` shown above.
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
**3. Impact and Benefits:**
- By enabling `annotations`, Python no longer executes the `VllmConfig |
None` operation during module load. Instead, it stores the annotation as
a string literal, completely avoiding the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has zero negative impact on
existing functionality.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -25,7 +25,6 @@ from vllm.logger import logger
|
||||
|
||||
|
||||
class AddRMSNormQuantPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.debug("Quant fusion not enabled: unsupported dtype %s",
|
||||
dtype)
|
||||
logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
common_epsilons = [1e-5, 1e-6]
|
||||
for eps in common_epsilons:
|
||||
AddRMSNormQuantPattern(vllm_config,
|
||||
eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
@@ -17,8 +17,7 @@
|
||||
#
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
PatternPrettyPrinter)
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
@@ -27,13 +26,7 @@ from vllm.logger import logger
|
||||
|
||||
|
||||
class QKNormRopeFusionPattern:
|
||||
|
||||
def __init__(self,
|
||||
vllm_config,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
eps=1e-6):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
cos=cos,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def __init__(self,
|
||||
vllm_config,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
eps=1e-6):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
q_normed = q_norm_out + q_bias
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
k_normed = k_norm_out + k_bias
|
||||
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
sin=sin,
|
||||
)
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPass(VllmInductorPass):
|
||||
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qknorm_rope_fusion_pass")
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
|
||||
dtype)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
|
||||
vllm_config, Attention)
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
if len(attn_layers) == 0:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
|
||||
)
|
||||
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
|
||||
return
|
||||
layer = next(iter(attn_layers.values()))
|
||||
for epsilon in [1e-6, 1e-5]:
|
||||
if layer.head_size != 128:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
|
||||
layer.head_size)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
|
||||
continue
|
||||
QKNormRopeFusionPattern(vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
QKNormRopeFusionPattern(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
QKNormRopeFusionPatternWithBias(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
Reference in New Issue
Block a user