Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
         gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
             mul,
             mm_weight,
-            "avg",
+            "sum",
             scatter_dim=0,
             group_name=self.tp.device_group.group_name,
         )
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
             mat2,
             scale_a,
             scale_b,
-            "avg",
+            "sum",
             scatter_dim,  # orig_scatter_dim
             scatter_dim,  # scatter_dim_after_maybe_reshape
             self.tp.device_group.group_name,
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
             mat2,
             scale_a,
             scale_b,
-            "avg",
+            "sum",
             scatter_dim,  # orig_scatter_dim
             scatter_dim,  # scatter_dim_after_maybe_reshape
             self.tp.device_group.group_name,
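
All three reduce-scatter fusions above switch the reduce op from "avg" to "sum". Under tensor parallelism each rank's matmul produces a partial sum over its shard of the reduction dimension, so the correct combine is a sum; averaging would scale the output by 1/world_size. A minimal sketch of the intended semantics, assuming an initialized torch.distributed process group (the helper name is illustrative, not vllm's API):

import torch
import torch.distributed as dist

def gemm_reduce_scatter(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Each TP rank multiplies against its shard of the reduction
    # dimension, so `partial` holds a partial sum of the full GEMM.
    partial = x @ w
    out = torch.empty(
        partial.shape[0] // dist.get_world_size(),
        *partial.shape[1:],
        dtype=partial.dtype,
        device=partial.device,
    )
    # Summing the partials across ranks reconstructs the exact product;
    # ReduceOp.AVG would divide the result by the world size.
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM)
    return out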

View File

@@ -5,7 +5,6 @@ import torch
 import torch._inductor.pattern_matcher as pm
 from torch import fx
 from torch._inductor.pattern_matcher import PatternMatcherPass
-from torch._ops import OpOverload

 import vllm.model_executor.layers.quantization.utils.fp8_utils  # noqa: F401
 from vllm._aiter_ops import rocm_aiter_ops
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     QuantKey,
     ScaleDesc,
+    kFp8Dynamic128Sym,
 )

 from vllm.platforms import current_platform
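
The newly imported kFp8Dynamic128Sym names a quantization scheme: dynamic, symmetric FP8 with one scale per group of 128 elements. As an illustration only (plain torch, not vllm's kernel), quantization of that shape looks roughly like:

import torch

def fp8_group_quant(x: torch.Tensor, group_size: int = 128):
    # Illustrative sketch: dynamic, symmetric FP8 quantization with one
    # scale per contiguous group of 128 elements along the last dim.
    # Assumes x.shape[-1] is a multiple of group_size.
    fp8 = torch.float8_e4m3fn
    grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # Symmetric: scale from each group's absmax, no zero point.
    scale = grouped.abs().amax(dim=-1, keepdim=True).float()
    scale = (scale / torch.finfo(fp8).max).clamp(min=1e-12)
    q = (grouped / scale).to(fp8).reshape_as(x)
    return q, scale.squeeze(-1)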
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
     @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph) -> None:
         self.matched_count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns", self.matched_count)
+        logger.debug(
+            "%s Replaced %s patterns", self.__class__.__name__, self.matched_count
+        )

     def uuid(self) -> str:
         fusion_patterns = [
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
     FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

-    def __init__(self, quant_op: OpOverload) -> None:
+    def __init__(self) -> None:
         self.silu_and_mul_matcher = MatcherSiluAndMul()
-        self.quant_op = quant_op
+        self.quant_matcher = MatcherQuantFP8(
+            quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
+        )

     def get_inputs(self) -> list[torch.Tensor]:
         return [
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
         input: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         at1 = self.silu_and_mul_matcher(input)
-        at2 = self.quant_op(at1, 128)
+        at2 = self.quant_matcher(at1)
         return at2[0], at2[1]

     def replacement(
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
     https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
     """

-    AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
-    TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
-    QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
-
     @enable_fake_mode
     def __init__(self, config: VllmConfig) -> None:
         super().__init__(config)
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
             pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
         )
-        for quant_op in self.QUANT_OPS:
-            AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
+        AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
         self.dump_patterns(config, self.patterns)
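
The pass now registers a single AiterSiluMulFp8GroupQuantPattern built on MatcherQuantFP8 instead of looping over raw quant ops. For context, this is roughly how a search/replace pair gets wired into Inductor's pattern matcher; a minimal, self-contained sketch with stand-in relu/clamp graphs rather than the real SiLU-mul + FP8 group-quant subgraphs:

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass()

def search(x: torch.Tensor) -> torch.Tensor:
    # Subgraph to look for (stand-in for silu_and_mul + fp8 group quant).
    return torch.relu(x)

def replace(x: torch.Tensor) -> torch.Tensor:
    # Equivalent subgraph to rewrite it to (stand-in for the fused op).
    return torch.clamp(x, min=0)

pm.register_replacement(
    search,
    replace,
    [torch.empty(8, 16)],  # example inputs used to trace both graphs
    pm.fwd_only,           # inference-only tracer
    patterns,              # the PatternMatcherPass to register into
)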

View File

@@ -18,7 +18,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kFp8StaticTensorSym,
 )
-from vllm.platforms import current_platform

 from ..inductor_pass import enable_fake_mode
 from ..utility.noop_elimination import NoOpEliminationPass
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
         )

-
-FP8_DTYPE = current_platform.fp8_dtype()
-

 class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
     def __init__(
         self,
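
With the module-level FP8_DTYPE constant gone (and the current_platform import above removed with it), the platform's FP8 dtype can be resolved where it is actually needed. A hedged sketch under that assumption (ExamplePattern is illustrative; current_platform.fp8_dtype() is vllm's real helper):

from vllm.platforms import current_platform

class ExamplePattern:
    def __init__(self) -> None:
        # Resolve the FP8 dtype from the platform at construction time,
        # e.g. torch.float8_e4m3fn on CUDA or torch.float8_e4m3fnuz on
        # some ROCm devices, instead of a module-level constant.
        self.fp8_dtype = current_platform.fp8_dtype()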