Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic128Sym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
logger.debug(
|
||||
"%s Replaced %s patterns", self.__class__.__name__, self.matched_count
|
||||
)
|
||||
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
|
||||
)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [
|
||||
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
at2 = self.quant_matcher(at1)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in self.QUANT_OPS:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
|
||||
rope_targets = [torch.ops._C.rotary_embedding.default]
|
||||
|
||||
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
|
||||
rope_targets.append(
|
||||
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
|
||||
)
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
if at_target in rope_targets:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
Reference in New Issue
Block a user