41 lines
1.7 KiB
Python
41 lines
1.7 KiB
Python
|
|
import torch
|
||
|
|
import torch._inductor.pattern_matcher as pm
|
||
|
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||
|
|
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||
|
|
from vllm.logger import logger
|
||
|
|
|
||
|
|
|
||
|
|
class AllGatherChunkNoOpCleanupPass(VllmInductorPass):
    """Remove ``all_gather`` -> ``sequence_parallel_chunk_impl`` pairs.

    The registered pattern matches a dim-0 ``all_gather`` over the
    tensor-parallel group whose result is fed directly into
    ``sequence_parallel_chunk_impl``; every match is rewritten to the
    original, un-gathered tensor.
    """

    def __init__(self, config: VllmConfig):
        """Capture tensor-parallel state and register the rewrite pattern.

        Args:
            config: vLLM configuration, forwarded unchanged to the base class.
        """
        super().__init__(config)
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

        # The pattern below closes over tp_group / tp_size, so both must be
        # assigned before registration runs.
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="npu_allgather_chunk_noop_cleanup_pass"
        )
        self._register_patterns()

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` all-gathered along dim 0 across the TP group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=group_name,
        )

    def _empty(self, *args, **kwargs):
        """Build an uninitialized example tensor for pattern tracing.

        ``dtype`` and ``device`` are taken from ``self.model_dtype`` and
        ``self.device``; neither is assigned in this class, so they are
        expected to be provided by the base class. Remaining arguments are
        passed through to ``torch.empty``.
        """
        return torch.empty(
            *args,
            dtype=self.model_dtype,
            device=self.device,
            **kwargs,
        )

    def _register_patterns(self) -> None:
        """Register the gather-then-chunk pattern and its replacement."""

        # NOTE(review): the nested function names and the ``input`` parameter
        # name are kept exactly as they were; they feed the traced pattern,
        # so confirm matcher behaviour before renaming any of them.
        def pattern(input: torch.Tensor) -> torch.Tensor:
            gathered = self._all_gather(input)
            return torch.ops.vllm.sequence_parallel_chunk_impl(gathered)

        def replacement(input: torch.Tensor) -> torch.Tensor:
            # The matched pair is folded away: hand back the tensor untouched.
            return input

        # A single small placeholder tensor is enough to trace both functions.
        example_inputs = [self._empty(8, 16)]
        pm.register_replacement(
            pattern,
            replacement,
            example_inputs,
            pm.fwd_only,
            self.patterns,
        )

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Run the registered pattern over ``graph`` and log the match count.

        Args:
            graph: FX graph handed to the pattern matcher.
        """
        self.begin()
        num_replaced = self.patterns.apply(graph)
        logger.debug(
            "AllGatherChunkNoOpCleanupPass replaced %s patterns", num_replaced
        )
        self.end_and_log()
|