################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
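"""Biren-optimized SwiGLU MLP layers for Llama-style models served with vLLM."""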
from typing import Optional

import torch_br
from torch import nn

from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_pp_group,
                                             get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig

from vllm_br import envs
from vllm_br.utils import get_grandparent_pid


class LlamaMlpSiluL3(nn.Module):
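    """Llama-style SwiGLU MLP with separate gate and up projections.

    The gate/up projections are column-parallel and the down projection is
    row-parallel; the SiLU-and-multiply step is computed with the Biren
    fused kernel ``torch_br.supa_silumul``. Only ``hidden_act == "silu"``
    is supported.
    """
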
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(input_size=hidden_size,
                                              output_size=intermediate_size,
                                              bias=bias,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.gate_proj")
        self.up_proj = ColumnParallelLinear(input_size=hidden_size,
                                            output_size=intermediate_size,
                                            bias=bias,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        up, _ = self.up_proj(x)
        # Fused SiLU(gate) * up on the Biren device.
        x = torch_br.supa_silumul(gate, up)
        x, _ = self.down_proj(x)
        return x


class MergedGateUpMLPSiluL2(nn.Module):
    """SwiGLU MLP with a merged gate/up projection and Biren fused kernels.

    The gate and up projections are packed into a single
    MergedColumnParallelLinear. The forward pass runs the fused MLP kernel
    ``torch_br.br_fused_mlp_infer`` with SwiGLU activation and, when the
    configuration allows it, a fused down-projection + all-reduce kernel.
    For prefixes containing "shared_experts", forward returns the raw
    gate_up/down weights instead of computing the MLP.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.intermediate_size = intermediate_size
        self.prefix = prefix
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.gate_up_proj.has_cross_weight = True
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
    def forward(self, x):
        # Cache the grandparent PID once; it is passed to the CPU/PCIe
        # all-reduce kernel below.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        if "shared_experts" not in self.prefix:
            quant_flag = hasattr(self.gate_up_proj, "qweight")
            hidden_size = x.shape[-1]
            seq_len = x.shape[-2]
            gu_weight = self.gate_up_proj.qweight if quant_flag else self.gate_up_proj.weight
            gu_scales = self.gate_up_proj.scales if quant_flag else None
            # Fused gate_up matmul + SwiGLU activation.
            gate_up_output = torch_br.br_fused_mlp_infer(
                x, [gu_weight],
                output_w=self.intermediate_size // self.tp_size,
                scales=[gu_scales] if gu_scales is not None else None,
                activation_mode="act_swiglu")
            down_weight = self.down_proj.qweight if quant_flag else self.down_proj.weight
            down_scales = self.down_proj.scales if quant_flag else None
            # Bypass the separate all-reduce for tp8 and tp4pp2 setups.
            pp_size = get_pp_group().world_size
            all_rank = self.tp_size * pp_size
            # Supported (VLLM_BR_DEVICE_SPC_NUM, tp_size) combinations.
            support_types = ((16, 4), (32, 2), (32, 4))
            if all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE and \
                    seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and \
                    (envs.VLLM_BR_DEVICE_SPC_NUM, self.tp_size) in support_types:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % self.tp_size
                assert rank_i == tp_rank
                # Fused down projection + all-reduce in a single kernel.
                down_output = torch_br.supa_fused_linear_allreduce_opt(
                    gate_up_output,
                    down_weight,
                    hidden_size,
                    tp_rank,
                    self.tp_size,
                    global_rank,
                    0,
                    scales=down_scales)
                return down_output
            else:
                down_output = torch_br.br_fused_mlp_infer(
                    gate_up_output, [down_weight],
                    output_w=hidden_size,
                    scales=[down_scales] if down_scales is not None else None)
                if self.tp_size > 1:
                    out = down_output
                    # Short sequences with tp_size >= 4 may use the
                    # CPU/PCIe all-reduce path.
                    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and \
                            self.tp_size >= 4 and out.shape[1] <= 32:
                        tp_rank = get_tensor_model_parallel_rank()
                        output = torch_br.supa_allreduce_pcie_infer(
                            out, tp_rank, self.tp_size, self.grandparent_pid)
                    else:
                        output = tensor_model_parallel_all_reduce(out)
                    return output
                else:
                    return down_output
        else:
            # Shared-expert layers return the raw weights instead of
            # computing the MLP here.
            return self.gate_up_proj.weight, self.down_proj.weight
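# ----------------------------------------------------------------------------
# Illustrative usage (sketch only, not part of the module). The sizes below
# are assumed Llama-7B values, and the module must be constructed inside an
# initialized vLLM tensor-parallel process group:
#
#     mlp = MergedGateUpMLPSiluL2(hidden_size=4096,
#                                 intermediate_size=11008,
#                                 hidden_act="silu",
#                                 quant_config=None,
#                                 prefix="model.layers.0.mlp")
#     hidden_states = mlp(hidden_states)  # [..., seq_len, hidden_size]
# ----------------------------------------------------------------------------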