# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inductor pass fusing tensor-parallel collectives into symmetric-memory ops.

Rewrites ``mm -> reduce_scatter`` and ``all_gather -> mm`` sequences into the
fused ``torch.ops.symm_mem`` kernels so communication overlaps with the GEMM.
"""
from typing import Optional

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.logger import init_logger

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class BasePattern:
    """Shared state (dtype, device, TP group info) for the fusion patterns."""

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype = dtype
        self.device = device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``aten.mm`` followed by ``vllm.reduce_scatter`` into
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self):
        # Dummy tensors used only to trace the pattern; shapes are arbitrary
        # placeholders, the matcher generalizes over them.
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(mul: torch.Tensor,
                    mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return reduce_scatter

        def replacement(mul: torch.Tensor,
                        mm_weight: torch.Tensor) -> torch.Tensor:
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return gemm_rs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherGEMMPattern(BasePattern):
    """Fuse ``vllm.all_gather`` followed by ``aten.mm`` into
    ``symm_mem.fused_all_gather_matmul``."""

    def get_inputs(self):
        # Dummy tracing tensors; shapes are placeholders for the matcher.
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):

        # NOTE: the original annotated both functions as returning
        # tuple[torch.Tensor, torch.Tensor]; the pattern actually returns a
        # single tensor and the replacement returns the list of matmul
        # outputs, so the annotations are corrected here.
        def pattern(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor,
                        weight: torch.Tensor) -> list[torch.Tensor]:
            # fused_all_gather_matmul returns (all_gather_output, [mm_out]);
            # only the matmul outputs replace the matched pattern's output.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AsyncTPPass(VllmInductorPass):
    """Inductor graph pass that applies the async-TP fusion patterns."""

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        GEMMReduceScatterPattern(self.model_dtype,
                                 self.device).register(self.patterns)
        AllGatherGEMMPattern(self.model_dtype,
                             self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        # Only replace for shapes that divide evenly across TP ranks.
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()