################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################ from functools import wraps from typing import Callable, Optional import torch import torch_br from fastcore.basics import patch_to from torch_br.utils.tensor_methods import Sbp from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod) from vllm.model_executor.utils import set_weight_attrs from vllm_br import envs from ..br_utils import (_convert_to_crossed_numa_tensor, _convert_to_numa_tensor, align_n, cross_weight_32) from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn, fused_oss_moe_dyn) @patch_to(UnquantizedFusedMoEMethod) def forward_oot( self, layer: FusedMoE, x: torch.Tensor, use_grouped_topk: bool, top_k: int, router_logits: torch.Tensor, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ): """Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support. """ if activation == "swigluoai": return fused_oss_moe_dyn( x, layer.w13_weight, layer.w13_bias, layer.w2_weight, layer.w2_bias, router_logits, top_k, layer.intermediate_size_per_partition, renormalize=renormalize, inplace=True, use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, ep_rank=layer.ep_rank, ep_size=layer.ep_size) b_seq = x.shape[0] gating_weight, shared_gate_up_weight, shared_down_weight = router_logits if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN: # prefill return fused_moe_quant_dyn( x, shared_gate_up_weight, shared_down_weight, layer.w13_weight, layer.w2_weight, None, None, gating_weight, top_k, layer.intermediate_size_per_partition, renormalize=renormalize, inplace=True, use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, tp_rank=get_tp_group().rank_in_group, global_rank=get_tp_group().rank, tp_size=get_tensor_model_parallel_world_size(), ep_rank=layer.ep_rank, ep_size=layer.ep_size) else: # decoder return fused_moe_quant_device( x, shared_gate_up_weight, shared_down_weight, layer.w13_weight, layer.w2_weight, None, None, gating_weight, top_k, layer.intermediate_size_per_partition, renormalize=renormalize, inplace=True, use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, tp_rank=get_tp_group().rank_in_group, global_rank=get_tp_group().rank, tp_size=get_tensor_model_parallel_world_size(), ep_rank=layer.ep_rank, ep_size=layer.ep_size) @patch_to(UnquantizedFusedMoEMethod) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, device="cpu", dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: w13_bias = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition, device="cpu", dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size_per_partition, device="cpu", dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) if self.moe.has_bias: w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, device="cpu", dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) @patch_to(UnquantizedFusedMoEMethod) def process_weights_after_loading(self: UnquantizedFusedMoEMethod, layer: FusedMoE) -> None: cur_device = torch.supa.current_device() die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM die_num = 1 if die_spc_num <= 16 else 2 spc_num = die_spc_num // die_num align_size = 32 if layer.activation == "swigluoai" else 64 is_dual_die = (die_spc_num > 16) # NOTE: w13_weight # after _load_w13, w13_weight is a colparallel weight, shape # [num_experts, 2 * intermediate_size_per_partition, hidden_size] # for SUPA, transform it to a NUMA colmajor weight, shape # [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64)) wk = layer.hidden_size wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num, align_size=align_size, spc_num=spc_num) supa_w13_weight = torch_br._empty_ut_only( size=(die_spc_num * layer.local_num_experts, wk, wn_block), dtype=torch.bfloat16, is_numa=True, device=cur_device, tensor_type="colmajor", axis=0, sbp="SS" if is_dual_die else None) for expert_id in range(layer.local_num_experts): expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous() # swigluoai activation, no need do interweave if layer.activation and layer.activation == "swigluoai": pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size, 'COLMAJOR', expert_w13.dtype) pad_expert_w13_shape = pad_expert_w13.shape hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1] narrow_data = supa_w13_weight.view_as_usharp( "COLMAJOR", pad_expert_w13_shape, Sbp.ss(0), expert_id * hw_size) narrow_data.copy_(pad_expert_w13) else: expert_1, expert_3 = expert_w13.chunk(2, dim=1) pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1, expert_3, die_spc_num, dim=1, need_pad=True, layout='COLMAJOR') hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1] narrow_data = supa_w13_weight.view_as_usharp( "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0), expert_id * hw_size) narrow_data.copy_(pad_expert_w13) layer.w13_weight.data = supa_w13_weight # NOTE: w13_bias if hasattr(layer, "w13_bias") and layer.w13_bias is not None: wn = layer.intermediate_size_per_partition * 2 supa_w13_bias = torch_br._empty_ut_only( size=(layer.local_num_experts, wn), dtype=torch.float32, is_numa=False, device=cur_device, tensor_type="linear_bias", sbp="BB" if is_dual_die else None) for expert_id in range(layer.local_num_experts): expert_w13_bias = layer.w13_bias[expert_id] # swigluoai activation, no need do interweave if layer.activation and layer.activation == "swigluoai": narrow_data = supa_w13_bias[expert_id] narrow_data.copy_(expert_w13_bias) else: expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1) crossed_expert_w13_bias = cross_weight_32( expert_1_bias, expert_3_bias, die_spc_num, dim=0, need_pad=False, ) narrow_data = supa_w13_bias[expert_id] narrow_data.copy_(crossed_expert_w13_bias) layer.w13_bias.data = supa_w13_bias # NOTE: w2_weight # after _load_w2, w2_weight is a rowparallel weight, shape # [num_experts, hidden_size, intermediate_size_per_partition] # for SUPA, transform it to a NUMA colmajor weight, shape # [spc_num * num_experts, wk, wn_block] align_size = 32 wk = layer.intermediate_size_per_partition wn_block = align_n(layer.hidden_size, align_size=align_size, spc_num=spc_num) supa_w2_weight = torch_br._empty_ut_only( size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block), dtype=torch.bfloat16, is_numa=True, device=cur_device, tensor_type="colmajor", axis=0, sbp="SS" if is_dual_die else None) for expert_id in range(layer.local_num_experts): expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous() pad_expert_w2 = _convert_to_numa_tensor(expert_w2, align_size, 'COLMAJOR', expert_w2.dtype, parallel_type="row_parallel") pad_expert_w2_shape = pad_expert_w2.shape hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1] narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR", pad_expert_w2_shape, Sbp.ss(0), expert_id * hw_size) narrow_data.copy_(pad_expert_w2) layer.w2_weight.data = supa_w2_weight # NOTE: w2_bias if hasattr(layer, "w2_bias") and layer.w2_bias is not None: wn = layer.hidden_size supa_w2_bias = torch.zeros((layer.local_num_experts, wn), dtype=torch.float32, device=cur_device) for expert_id in range(layer.local_num_experts): expert_w2 = layer.w2_bias[expert_id] narrow_data = supa_w2_bias[expert_id] narrow_data.copy_(expert_w2) layer.w2_bias.data = supa_w2_bias @patch_to(FusedMoE) def forward(self: FusedMoE, hidden_states: torch.Tensor, router_logits: torch.Tensor): """ ! router_logits is a tuple of gate, shared_experts.gate_up_proj, shared_experts.down_proj weights. """ assert self.quant_method is not None assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes' final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, ) # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the final_hidden_states support_types = ((16, 4), (16, 8), (32, 2), (32, 4)) tp_size = get_tensor_model_parallel_world_size() if hidden_states.shape[ 0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and envs.VLLM_BR_QUANT_METHOD != "INT4" and envs.VLLM_BR_USE_FUSED_ALLREDUCE and ( envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types: final_hidden_states.all_reduced = True return final_hidden_states @patch_to(FusedMoE) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight.cpu()) @patch_to(FusedMoE) def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int, load_full: bool = False): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight.cpu()) def wrapper_FusedMoE_init(fn): @wraps(fn) def wrapper(self, *args, **kwargs): fn(self, *args, **kwargs) if self.e_score_correction_bias is not None: self.e_score_correction_bias.data = self.e_score_correction_bias.float( ) return wrapper FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501