[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
127
vllm/compilation/collective_fusion.py
Normal file
127
vllm/compilation/collective_fusion.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
    """Common state shared by the async-TP fusion patterns below.

    Caches the model dtype/device plus the tensor-parallel group handles
    that both the traced patterns and their replacements reference.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        # dtype/device of the sample tensors used to trace the pattern.
        self.dtype = dtype
        self.device = device
        # TP group coordinator (provides unique_name / device_group) and
        # the TP world size, looked up once at construction time.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten.mm`` followed by ``vllm.reduce_scatter`` into a single
    symmetric-memory ``fused_matmul_reduce_scatter`` op, overlapping the
    GEMM with the scatter communication.
    """

    def get_inputs(self):
        """Sample tensors the inductor tracer uses to build the pattern."""
        activation = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [activation, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            # Matched subgraph: plain matmul, then reduce-scatter along
            # dim 0 over the TP group.
            product = torch.ops.aten.mm.default(mul, mm_weight)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # Fused symmetric-memory op replacing the two-node subgraph.
            # NOTE(review): reduce op is "avg" — confirm this matches the
            # reduction semantics of vllm.reduce_scatter for this group.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` followed by ``aten.mm`` into a single
    symmetric-memory ``fused_all_gather_matmul`` op, overlapping the gather
    communication with the GEMM.
    """

    def get_inputs(self):
        """Sample tensors the inductor tracer uses to build the pattern."""
        return [
            torch.empty([4, 4], device=self.device, dtype=self.dtype),
            torch.empty([4, 4], device=self.device, dtype=self.dtype),
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Matched subgraph: all-gather the activations along dim 0,
            # then matmul with the local weight shard.
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return torch.ops.aten.mm.default(gathered, weight)

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            # The fused op returns (gathered activations, [matmul outputs]);
            # only the matmul results replace the matched subgraph.
            _gathered, matmul_outputs = (
                torch.ops.symm_mem.fused_all_gather_matmul(
                    x,
                    [weight],
                    gather_dim=0,
                    group_name=self.tp.device_group.group_name,
                ))
            return matmul_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmInductorPass):
    """Inductor pass that fuses TP collectives (mm+reduce-scatter,
    all-gather+mm) into single symmetric-memory async-TP ops.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Symmetric memory must be enabled for the TP process group before
        # any symm_mem fused op can execute on it.
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        # Register both fusion patterns against this pass's matcher.
        for pattern_cls in (GEMMReduceScatterPattern, AllGatherGEMMPattern):
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Only rewrite when the shape splits evenly across TP ranks."""
        world_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % world_size == 0

    def __call__(self, graph: fx.Graph):
        # Standard VllmInductorPass bookkeeping around the rewrite:
        # dump the graph before/after and log how many matches fired.
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
|
||||
Reference in New Issue
Block a user