[Graph][Fusion] Integrating inductor pass and npugraph ex pass (#6354)
### What this PR does / why we need it?
Integrate the inductor pass and the npugraph ex pass; see the RFC:
https://github.com/vllm-project/vllm-ascend/issues/6347
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
All tests passed.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -16,12 +16,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.compilation.passes.base_pattern import BasePattern
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("v0.15.0"):
|
||||
@@ -32,15 +32,14 @@ else:
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
|
||||
|
||||
class QKNormRopeFusionPattern:
|
||||
class QKNormRopeFusionPattern(BasePattern):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
super().__init__(vllm_config, eps)
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.eps = eps
|
||||
self.device = vllm_config.device_config.device if vllm_config.device_config else None
|
||||
|
||||
def get_inputs(self):
|
||||
@@ -53,7 +52,7 @@ class QKNormRopeFusionPattern:
|
||||
positions = torch.ones(T, dtype=torch.int64, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos_sin_cache, positions]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_pattern(self):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
@@ -77,6 +76,9 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
return pattern
|
||||
|
||||
def get_replacement(self):
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
@@ -100,18 +102,17 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
return replacement
|
||||
|
||||
|
||||
class QKNormRopeFusionPatternWithBias:
|
||||
class QKNormRopeFusionPatternWithBias(BasePattern):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
super().__init__(vllm_config, eps)
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.eps = eps
|
||||
self.vllm_config = vllm_config
|
||||
self.device = vllm_config.device_config.device if vllm_config.device_config else None
|
||||
|
||||
def get_inputs(self):
|
||||
@@ -127,7 +128,7 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos_sin_cache, positions]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_pattern(self):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
@@ -155,6 +156,9 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
return pattern
|
||||
|
||||
def get_replacement(self):
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
@@ -179,7 +183,7 @@ class QKNormRopeFusionPatternWithBias:
|
||||
)
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
return replacement
|
||||
|
||||
|
||||
class QKNormRopeFusionPass(VllmInductorPass):
|
||||
|
||||
Reference in New Issue
Block a user