# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Async tensor-parallelism fusion pass for torch.compile (inductor).

Registers pattern-matcher replacements that fuse a matmul with its adjacent
tensor-parallel collective (``reduce_scatter`` after the matmul, or
``all_gather`` before it) into the corresponding PyTorch symmetric-memory
fused op (``torch.ops.symm_mem.*``), which overlaps communication with
computation.
"""
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

# Platform-specific fp8 dtype used by the scaled-matmul patterns below.
FP8_DTYPE = current_platform.fp8_dtype()

logger = init_logger(__name__)


class BasePattern:
    """Shared state for all fusion patterns: model dtype, target device,
    and the tensor-parallel process group info used to name collectives."""

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        self.dtype = dtype
        self.device = device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``mm -> vllm.reduce_scatter`` into
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Small sample tensors used only to trace the pattern/replacement
        # graphs for register_replacement; shapes are placeholders.
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # "avg" reduce op; scatter along dim 0 to mirror the original
            # reduce_scatter(dim=0) above.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherGEMMPattern(BasePattern):
    """Fuse ``vllm.all_gather -> mm`` into
    ``symm_mem.fused_all_gather_matmul``."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Placeholder trace inputs (see GEMMReduceScatterPattern.get_inputs).
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # fused_all_gather_matmul takes a list of weights and returns
            # (all_gather output, list of matmul outputs).
            # NOTE(review): the list `mm_outputs` is returned where the
            # pattern returns a single tensor — presumably the pattern
            # matcher unwraps single-element replacement outputs; confirm.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse ``aten._scaled_mm -> vllm.reduce_scatter`` into vLLM's
    ``patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Placeholder trace inputs. The weight is built contiguous then
        # transposed so its memory layout matches how _scaled_mm receives
        # a column-major mat2.
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Per-row scale for input, per-column scale for the weight.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [input, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )
            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherScaledMMPattern(BasePattern):
    """Fuse ``vllm.all_gather -> aten._scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self) -> list[torch.Tensor]:
        # Placeholder trace inputs; scale_a is sized for the post-all-gather
        # row count (local rows * tp_size), one scale per gathered row.
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        s1 = x.shape[0] * self.tp_size
        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Per-matmul args (weights, scales, biases, ...) are passed as
            # lists; returns (all_gather output, list of matmul outputs).
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse ``_C.cutlass_scaled_mm -> vllm.reduce_scatter`` into vLLM's
    ``patched_fused_scaled_matmul_reduce_scatter``.

    The cutlass op is an out-parameter custom op, so the pattern matches it
    through ``auto_functionalized`` and takes its functionalized output
    (index 1 of the returned tuple).
    """

    def get_inputs(self) -> list[torch.Tensor]:
        # Placeholder trace inputs, plus the preallocated output buffer the
        # out-parameter cutlass op writes into.
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # cutlass_scaled_mm[1] is the functionalized `out` tensor.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Calculate output shape: input @ mat2 with scatter_dim reduced
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )
            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse ``vllm.all_gather -> _C.cutlass_scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``.

    Like CutlassScaledMMReduceScatterPattern, the cutlass op is matched via
    ``auto_functionalized`` with a preallocated output buffer.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        # Placeholder trace inputs; scale_a and the output buffer are sized
        # for the post-all-gather row count (local rows * tp_size).
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        s1 = x.shape[0] * self.tp_size
        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        s2 = weight.shape[1]
        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 is the functionalized `out` tensor.
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Returns (all_gather output, list of matmul outputs); the
            # preallocated `output` buffer is no longer needed here.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AsyncTPPass(VllmPatternMatcherPass):
    """Inductor pass that applies all async-TP matmul/collective fusions.

    Registers the pattern set above into a single PatternMatcherPass and
    applies it to the graph when called.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )
        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )
        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        tp_size = get_tensor_model_parallel_world_size()
        # Otherwise: only a fixed-size range whose size divides evenly by TP.
        return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)