Files
xc-llm-ascend/vllm_ascend/compilation/passes/allgather_chunk_noop_pass.py

41 lines
1.7 KiB
Python
Raw Permalink Normal View History

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
    """Rewrite ``all_gather`` followed by ``sequence_parallel_chunk_impl`` to a no-op.

    The pass registers a single search/replace pattern: a tensor that is
    all-gathered along dim 0 across the tensor-parallel group and then fed
    straight into ``torch.ops.vllm.sequence_parallel_chunk_impl`` is replaced
    by the original tensor.

    NOTE(review): the rewrite is only sound if the chunk op hands back exactly
    this rank's slice of the gathered tensor -- confirm against the op's
    implementation.
    """

    def __init__(self, config: VllmConfig):
        """Capture tensor-parallel info and register the rewrite pattern.

        Args:
            config: vLLM configuration, forwarded unchanged to the base pass.
        """
        super().__init__(config)
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="npu_allgather_chunk_noop_cleanup_pass"
        )
        # Registration reads tp_group / tp_size / patterns, so it must run last.
        self._register_patterns()

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Call the vLLM all_gather op on ``x`` along dim 0 over the TP group."""
        gather_op = torch.ops.vllm.all_gather
        return gather_op(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )

    def _empty(self, *args, **kwargs):
        """Return ``torch.empty`` with this pass's ``model_dtype`` and ``device``.

        Positional and keyword arguments are forwarded to ``torch.empty``.
        ``model_dtype`` and ``device`` are not set in this class; presumably
        the base pass provides them -- verify.
        """
        return torch.empty(
            *args, dtype=self.model_dtype, device=self.device, **kwargs
        )

    def _register_patterns(self) -> None:
        """Register the gather-then-chunk pattern and its identity replacement."""

        # The parameter is named ``input`` in both functions on purpose: the
        # pattern matcher derives its argument names from the search function.
        def pattern(input: torch.Tensor) -> torch.Tensor:
            full_tensor = self._all_gather(input)
            return torch.ops.vllm.sequence_parallel_chunk_impl(full_tensor)

        def replacement(input: torch.Tensor) -> torch.Tensor:
            return input

        # A small placeholder tensor used only to trace the two functions.
        example_inputs = [self._empty(8, 16)]
        pm.register_replacement(
            pattern,
            replacement,
            example_inputs,
            pm.fwd_only,
            self.patterns,
        )

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply the registered pattern to ``graph`` and log the match count.

        Args:
            graph: FX graph passed to ``PatternMatcherPass.apply``.
        """
        self.begin()
        num_replaced = self.patterns.apply(graph)
        logger.debug(
            "AllGatherChunkNoOpCleanupPass replaced %s patterns", num_replaced
        )
        self.end_and_log()