# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

"""Inference-only MOE model."""
from typing import Any, List, Optional, Dict, Tuple
from dataclasses import dataclass

import torch
from torch import nn

import vllm.envs as envs  # needed by grouped_topk below
from vllm.config import get_current_vllm_config
from vllm.distributed import (
    get_moe_tensor_parallel_rank,
    get_moe_tensor_parallel_world_size,
    get_moe_tensor_parallel_group,
    get_moe_expert_parallel_rank,
    get_moe_expert_parallel_world_size,
    get_moe_expert_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    get_dp_group,
    divide,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.fused_moe import fused_grouped_topk
from vllm.utils.torch_utils import get_dtype_size
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu._mlu_utils import *
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.distributed.parallel_state import (
    CnclEP, cnclep_dispatch, cnclep_combine)


@dataclass
class MoeGroupInfo:
    tp_rank: int
    tp_size: int
    dp_rank: int
    dp_size: int
    moe_tp_size: int
    moe_tp_rank: int
    moe_ep_size: int
    moe_ep_rank: int
    moe_group: Any
    moe_kwargs: dict

    def __init__(self):
        self.tp_rank = get_tp_group().rank_in_group
        self.tp_size = get_tp_group().world_size
        self.dp_rank = get_dp_group().rank_in_group
        self.dp_size = get_dp_group().world_size

        self.moe_tp_size = get_moe_tensor_parallel_world_size()
        self.moe_tp_rank = get_moe_tensor_parallel_rank()
        self.moe_tp_group = get_moe_tensor_parallel_group()
        self.moe_ep_size = get_moe_expert_parallel_world_size()
        self.moe_ep_rank = get_moe_expert_parallel_rank()
        self.moe_ep_group = get_moe_expert_parallel_group()
        self.moe_group = self.moe_ep_group if self.moe_ep_size > 1 else self.moe_tp_group
        self.moe_kwargs = {"tp_group": self.moe_tp_group}


class SqrtSoftPlusTopK(torch.nn.Module):

    def __init__(self,
                 score_func: str,
                 use_hash: bool,
                 n_routed_experts: int,
                 n_activated_experts: int,
                 route_scale: float,
                 vocab_size: int,
                 prefix: str = ""):
        super().__init__()
        self.topk = n_activated_experts
        self.n_activated_experts = n_activated_experts
        self.score_func = score_func
        self.route_scale = route_scale
        self.use_hash = use_hash
        self.n_routed_experts = n_routed_experts
        self.vocab_size = vocab_size
        if self.use_hash:
            self.tid2eid = nn.Parameter(
                torch.randint(0,
                              self.n_activated_experts,
                              (self.vocab_size, self.n_activated_experts),
                              dtype=torch.int32),
                requires_grad=False,
            )
            self.bias = None
        else:
            self.tid2eid = None
            self.bias = nn.Parameter(torch.empty(self.n_routed_experts, dtype=torch.float32), requires_grad=False)

    def forward(self, scores: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        assert self.score_func == "sqrtsoftplus"
        return mlu_ops.moe_softplus_topk(
            scores,
            self.topk,
            input_ids,
            self.tid2eid,
            self.bias,
            self.route_scale,
        )


# This is used by the Deepseek-V2 and Deepseek-V3 model
'''
=============================
Modify by vllm_mlu
=============================
@brief: comment out decorator torch.compiler to avoid triton bug for torch_mlu 2.9.1
'''
# @torch.compile(
#     dynamic=True,
#     backend=current_platform.simple_compile_backend,
#     options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
# )
'''
==================
End of MLU Hijack
==================
'''
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
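
# Illustrative note on grouped_topk above (hypothetical sizes): with 256 experts
# in num_expert_group=8 groups of 32, topk_group=4 and topk=8, each group is
# first scored (sum of its top-2 biased expert scores when a correction bias is
# given, otherwise the group max), the best 4 groups are kept, the remaining
# 224 experts are masked to -inf, and only then is the global top-8 taken.
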
class SparseMoeMlp(nn.Module):
    """
    Tensor Parallel evenly splits each expert's weights and distributes the
    slices across ranks, so every rank holds a partial copy of all experts.
    Expert Parallel instead distributes whole experts across ranks, so every
    rank holds the full weights of only a subset of the experts.

    As a result, under Tensor Parallel each rank receives every token's hidden
    states for all experts and computes with its partial weights, while under
    Expert Parallel each rank only receives the hidden states of the tokens
    routed to its local experts and computes with their full weights.

    When both Tensor Parallel and Expert Parallel are enabled, each rank handles
    a portion of the expert weight matrices (as in EP mode) and these weights are
    further sliced across ranks (as in TP mode). This hybrid approach balances
    the workload more evenly across ranks, enhancing efficiency and reducing the
    likelihood of bottlenecks associated with EP mode alone.
    """
    reduce_weight: torch.Tensor = None
    expert_id: torch.Tensor = None
    is_expert_avg: bool = False
    max_batched_token: int = 2048
    random_idx: int = 0

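    # Illustrative sketch of the hybrid partitioning described in the class
    # docstring (hypothetical sizes): with 160 experts, moe_ep_size=8 and
    # moe_tp_size=2, each rank owns 160 / 8 = 20 experts, and each of those
    # experts' projections is further halved along the intermediate dimension,
    # so a rank holds 20 experts x (intermediate_size / 2) weight columns.
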
def __init__(
|
||
self,
|
||
num_experts: int,
|
||
top_k: int,
|
||
hidden_size: int,
|
||
intermediate_size: int,
|
||
up_proj_name: str,
|
||
is_gated: bool,
|
||
down_proj_name: str,
|
||
has_bias: bool,
|
||
skip_bias_add: bool = False,
|
||
renormalize:bool = False,
|
||
hidden_act: str = "silu",
|
||
params_dtype: torch.dtype | None = None,
|
||
quant_config: QuantizationConfig | None = None,
|
||
is_use_fused_moe: bool = False,
|
||
expert_group: int | None = 1,
|
||
topk_group: int | None = 1,
|
||
scoring_func: str = "softmax",
|
||
topk_method: str = "",
|
||
routed_scaling_factor: float = 1.0,
|
||
tp_group: Any = None,
|
||
use_all2all: bool = False,
|
||
use_hash: bool = False,
|
||
vocab_size: int = 0,
|
||
prefix: str = "",
|
||
init_avg_moe: bool = True,
|
||
):
|
||
super().__init__()
|
||
if tp_group is None:
|
||
self.tp_rank = get_tensor_model_parallel_rank()
|
||
self.tp_size = get_tensor_model_parallel_world_size()
|
||
self.tp_group = get_tp_group()
|
||
else:
|
||
self.tp_rank = tp_group.rank_in_group
|
||
self.tp_size = tp_group.world_size
|
||
self.tp_group = tp_group
|
||
self.use_hash = use_hash
|
||
self.num_total_experts = num_experts
|
||
self.top_k = top_k
|
||
self.hidden_size = hidden_size
|
||
self.intermediate_size = intermediate_size
|
||
self.up_proj_name = up_proj_name
|
||
self.is_gated = is_gated
|
||
self.down_proj_name = down_proj_name
|
||
self.has_bias = has_bias
|
||
self.renormalize = renormalize
|
||
self.hidden_act = hidden_act
|
||
self.quant_config = quant_config
|
||
self.is_use_fused_moe = is_use_fused_moe
|
||
self.expert_group = expert_group
|
||
self.topk_group = topk_group
|
||
self.scoring_func = scoring_func
|
||
self.routed_scaling_factor = routed_scaling_factor
|
||
self.use_all2all = use_all2all
|
||
self.vocab_size = vocab_size
|
||
# fused_moe doesn't support weightonly quantization
|
||
if isinstance(quant_config, WeightOnlyConfig):
|
||
self.is_use_fused_moe = False
|
||
|
||
if params_dtype is None:
|
||
params_dtype = torch.get_default_dtype()
|
||
self.params_dtype = params_dtype
|
||
# [num_bytes_hidden_states, num_bytes_reduce_weights, num_bytes_expert_id]
|
||
self.precompute_dim_bytes_list: List[int] | None = None
|
||
# sum(self.precompute_dim_bytes_list)
|
||
self.precompute_dim_bytes = -1
|
||
|
||
moe_group_info = MoeGroupInfo()
|
||
self.moe_tp_size = moe_group_info.moe_tp_size
|
||
self.moe_tp_rank = moe_group_info.moe_tp_rank
|
||
self.moe_ep_size = moe_group_info.moe_ep_size
|
||
self.moe_ep_rank = moe_group_info.moe_ep_rank
|
||
self.dp_size = moe_group_info.dp_size
|
||
self.dp_rank = moe_group_info.dp_rank
|
||
self.moe_group = moe_group_info.moe_group
|
||
self.moe_kwargs = moe_group_info.moe_kwargs
|
||
|
||
vllm_config = get_current_vllm_config()
|
||
model_config = getattr(vllm_config, "model_config", None)
|
||
hf_text_config = getattr(model_config, "hf_text_config", None)
|
||
self.model_type = getattr(hf_text_config, "model_type", "")
|
||
|
||
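        # NOTE: VLLM_AVG_MOE_EN / VLLM_RANDOM_MOE_EN appear to be load-balancing
        # debug/benchmark switches: they replace the learned routing with uniform
        # weights (1 / num_experts) and a round-robin or random expert table,
        # forcing an evenly (or randomly) balanced expert assignment.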
if (init_avg_moe and
|
||
VLLM_AVG_MOE_EN and not SparseMoeMlp.is_expert_avg):
|
||
n_tokens = SparseMoeMlp.max_batched_token * self.dp_size
|
||
expert_group = self.moe_ep_size
|
||
val = 1.0 / float(num_experts)
|
||
SparseMoeMlp.reduce_weight = torch.full((n_tokens, top_k), val, device="mlu", dtype=torch.float32)
|
||
import math
|
||
if VLLM_RANDOM_MOE_EN:
|
||
import numpy as np
|
||
# example deepseekv2: experts 160 topk 6
|
||
# avg list: 92, 8, 88, 45, 99, 9,... 118, 142, 116, 57, 104, 6,......
|
||
array = np.stack([np.random.permutation(num_experts)[:top_k] for _ in range(n_tokens)])
|
||
table = torch.from_numpy(array.flatten()).to(device="mlu", dtype=torch.int32)
|
||
else:
|
||
# example deepseekv2: experts 160
|
||
# avg list: 0,20,40,60,80...120,140, 1,21,...121,141, 2...142, ...... 19,...159, 0,20,......
|
||
batch_table = math.ceil(n_tokens * top_k / num_experts) * num_experts
|
||
hi_val = batch_table // num_experts
|
||
table = (torch.arange(hi_val * num_experts, device="mlu", dtype=torch.int32) % num_experts).view(
|
||
hi_val, expert_group, num_experts // expert_group).transpose(1, 2)
|
||
SparseMoeMlp.expert_id = table.flatten()[:n_tokens * top_k].view(n_tokens, top_k)
|
||
SparseMoeMlp.is_expert_avg = True
|
||
        # NOTE: The bias for fc2 is only applied on tp_rank 0. If it were added on
        # every rank, the all-reduce would accumulate multiple copies of the bias.
        # The bias on the other ranks is ignored and may be set to nullptr.
        self.skip_bias_add = self.moe_tp_rank > 0
|
||
|
||
assert self.num_total_experts >= self.moe_ep_size, (
|
||
f"need num_total_experts:{self.num_total_experts} >= moe_ep_size:{self.moe_ep_size}")
|
||
|
||
assert self.intermediate_size % self.moe_tp_size == 0, (
|
||
f"need intermediate_size:{self.intermediate_size} % moe_tp_size:{self.moe_tp_size} == 0")
|
||
|
||
self.num_experts_per_rank = (self.num_total_experts + self.moe_ep_size - 1) // self.moe_ep_size
|
||
if self.moe_ep_rank + 1 == self.moe_ep_size and self.num_total_experts % self.moe_ep_size:
|
||
self.num_experts_per_rank = self.num_total_experts % self.moe_ep_size
|
||
|
||
self.start_expert_id = self.moe_ep_rank * ((self.num_total_experts + self.moe_ep_size - 1) // self.moe_ep_size)
|
||
self.end_expert_id = self.start_expert_id + self.num_experts_per_rank
|
||
|
||
# Gate always runs at half / full precision for now.
|
||
self.gate = ReplicatedLinear(
|
||
self.hidden_size,
|
||
self.num_total_experts,
|
||
bias=False,
|
||
params_dtype=self.params_dtype,
|
||
quant_config=None,
|
||
)
|
||
if self.is_deepseek_v4:
|
||
self.deepseekv4_topk = SqrtSoftPlusTopK(
|
||
score_func=self.scoring_func,
|
||
use_hash=self.use_hash,
|
||
n_routed_experts=self.num_total_experts,
|
||
n_activated_experts=self.top_k,
|
||
route_scale=self.routed_scaling_factor,
|
||
vocab_size=self.vocab_size,
|
||
prefix=f"{prefix}.topk",
|
||
)
|
||
if topk_method == "noaux_tc":
|
||
self.gate.e_score_correction_bias = nn.Parameter(
|
||
torch.empty(self.num_total_experts, device="mlu"))
|
||
else:
|
||
self.gate.e_score_correction_bias = None
|
||
self.is_fp8_block_wise = (isinstance(self.quant_config, Fp8Config)
|
||
and (self.quant_config.weight_block_size is not None))
|
||
if self.is_fp8_block_wise:
|
||
self.experts = FusedMoE(
|
||
num_experts=self.num_experts_per_rank,
|
||
top_k=self.top_k,
|
||
hidden_size=self.hidden_size,
|
||
intermediate_size=self.intermediate_size,
|
||
reduce_results=False,
|
||
renormalize=self.renormalize,
|
||
quant_config=self.quant_config,
|
||
use_grouped_topk=True,
|
||
num_expert_group=self.expert_group,
|
||
topk_group=self.topk_group,
|
||
prefix=f"{prefix}.experts",
|
||
scoring_func=self.scoring_func,
|
||
e_score_correction_bias=self.gate.e_score_correction_bias)
|
||
else:
|
||
self.experts = nn.ModuleList([
|
||
FeedForward(hidden_size=self.hidden_size,
|
||
intermediate_size=self.intermediate_size,
|
||
hidden_act=self.hidden_act,
|
||
up_proj_name=self.up_proj_name,
|
||
is_gated=self.is_gated,
|
||
down_proj_name=self.down_proj_name,
|
||
bias=self.has_bias,
|
||
quant_config=self.quant_config,
|
||
skip_bias_add=self.skip_bias_add,
|
||
reduce_results=False,
|
||
prefix=f"experts.{idx}",
|
||
**self.moe_kwargs) for idx in range(self.num_experts_per_rank)
|
||
])
|
||
|
||
self.init_pack_param()
|
||
|
||
@property
|
||
def is_deepseek_v4(self):
|
||
return self.scoring_func == 'sqrtsoftplus'
|
||
|
||
@property
|
||
def is_kimi_k2(self):
|
||
kimi_k2_scoring_func = "sigmoid"
|
||
kimi_k2_expert_group_num = 1
|
||
kimi_k2_experts_num = 384
|
||
return (self.scoring_func == kimi_k2_scoring_func
|
||
and self.expert_group == kimi_k2_expert_group_num
|
||
and self.num_total_experts == kimi_k2_experts_num)
|
||
|
||
@property
|
||
def is_glm4_moe(self):
|
||
return self.model_type == "glm4_moe"
|
||
|
||
def init_pack_param(self):
|
||
self.w13 = None
|
||
self.w2 = None
|
||
self.b13 = None
|
||
self.b2 = None
|
||
self.w13_scale = None
|
||
self.w2_scale = None
|
||
self.a13_scale = None
|
||
self.a13_scale_all_experts = None
|
||
self.a2_scale = None
|
||
self.pack_params_done = False
|
||
self.pack_params_after_loading_done = False
|
||
|
||
|
||
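    # map_param_data below packs the per-expert parameter tensors into one
    # contiguous buffer (torch._utils._flatten_dense_tensors), re-points each
    # expert's .data at its slice of that buffer, and returns a
    # [num_experts, ...] view so the fused MoE kernels can address all experts
    # through a single tensor.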
def map_param_data(self, param_list, is_use_first_data=False):
|
||
if len(param_list) == 0:
|
||
return None
|
||
|
||
if is_use_first_data or len(param_list) == 1:
|
||
first_data = param_list[0].data
|
||
for param in param_list[1: -1]:
|
||
param.data = first_data
|
||
if is_use_first_data:
|
||
out_param = first_data.view_as(param_list[0])
|
||
else:
|
||
out_param = first_data.view(len(param_list), *first_data.shape)
|
||
else:
|
||
packed_param = torch._utils._flatten_dense_tensors(param_list)
|
||
data_list = torch._utils._unflatten_dense_tensors(packed_param, param_list)
|
||
for data, param in zip(data_list, param_list):
|
||
param.data = data
|
||
out_param = packed_param.view(len(param_list), *data_list[0].shape)
|
||
|
||
torch.mlu.empty_cache()
|
||
|
||
return out_param
|
||
|
||
|
||
def pack_unquantized_params(self, w13, w2, b13, b2):
|
||
for expert in self.experts:
|
||
up_proj = getattr(expert, self.up_proj_name)
|
||
down_proj = getattr(expert, self.down_proj_name)
|
||
w13.append(up_proj.weight)
|
||
w2.append(down_proj.weight)
|
||
if self.has_bias:
|
||
b13.append(up_proj.bias)
|
||
b2.append(down_proj.bias)
|
||
|
||
|
||
def pack_smoothquant_params(self, w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale):
|
||
for expert in self.experts:
|
||
up_proj = getattr(expert, self.up_proj_name)
|
||
down_proj = getattr(expert, self.down_proj_name)
|
||
w13.append(up_proj.qweight)
|
||
w2.append(down_proj.qweight)
|
||
if self.has_bias:
|
||
b13.append(up_proj.bias)
|
||
b2.append(down_proj.bias)
|
||
w13_scale.append(up_proj.per_channel_scale)
|
||
w2_scale.append(down_proj.per_channel_scale)
|
||
if self.quant_config.input_quant_method == "per_token":
|
||
a13_scale.append(up_proj.smooth)
|
||
a2_scale.append(down_proj.smooth)
|
||
else:
|
||
a13_scale.append(up_proj.scale_to_int)
|
||
a2_scale.append(down_proj.scale_to_int)
|
||
|
||
|
||
def pack_weightonly_params(self, w13, w2, b13, b2, w13_scale, w2_scale):
|
||
for expert in self.experts:
|
||
up_proj = getattr(expert, self.up_proj_name)
|
||
down_proj = getattr(expert, self.down_proj_name)
|
||
w13.append(up_proj.qweight)
|
||
w2.append(down_proj.qweight)
|
||
if self.has_bias:
|
||
b13.append(up_proj.bias)
|
||
b2.append(down_proj.bias)
|
||
w13_scale.append(up_proj.scales)
|
||
w2_scale.append(down_proj.scales)
|
||
|
||
def pack_fp8_params_without_activation_scheme(self, w13, w2, b13, b2, w13_scale, w2_scale):
|
||
for expert in self.experts:
|
||
up_proj = getattr(expert, self.up_proj_name)
|
||
down_proj = getattr(expert, self.down_proj_name)
|
||
w13.append(up_proj.weight)
|
||
w2.append(down_proj.weight)
|
||
if self.has_bias:
|
||
b13.append(up_proj.bias)
|
||
b2.append(down_proj.bias)
|
||
w13_scale.append(up_proj.weight_scale)
|
||
w2_scale.append(down_proj.weight_scale)
|
||
|
||
|
||
def pack_params(self):
|
||
if self.pack_params_done or self.is_fp8_block_wise:
|
||
return
|
||
|
||
w13 = []
|
||
w2 = []
|
||
b13 = []
|
||
b2 = []
|
||
w13_scale = []
|
||
w2_scale = []
|
||
a13_scale = []
|
||
a2_scale = []
|
||
|
||
if self.quant_config is None:
|
||
self.pack_unquantized_params(w13, w2, b13, b2)
|
||
elif isinstance(self.quant_config, SmoothQuantConfig):
|
||
self.pack_smoothquant_params(w13, w2, b13, b2, w13_scale, w2_scale, a13_scale, a2_scale)
|
||
elif isinstance(self.quant_config, WeightOnlyConfig):
|
||
self.pack_weightonly_params(w13, w2, b13, b2, w13_scale, w2_scale)
|
||
elif isinstance(self.quant_config, Fp8Config) and self.quant_config.activation_scheme == 'dynamic':
|
||
self.pack_fp8_params_without_activation_scheme(w13, w2, b13, b2, w13_scale, w2_scale)
|
||
else:
|
||
raise ValueError(f'Unsupported quantization:{self.quant_config}')
|
||
|
||
# pack weight
|
||
self.w13 = self.map_param_data(w13)
|
||
self.w2 = self.map_param_data(w2)
|
||
|
||
# pack bias
|
||
if self.has_bias:
|
||
self.b13 = self.map_param_data(b13)
|
||
            # NOTE: The bias for fc2 is only applied on tp_rank 0. If it were added on
            # every rank, the all-reduce would accumulate multiple copies of the bias.
            # The bias on the other ranks is ignored and may be set to nullptr.
if self.skip_bias_add is False:
|
||
self.b2 = self.map_param_data(b2)
|
||
|
||
|
||
# pack weight scale
|
||
if len(w13_scale) > 0:
|
||
self.w13_scale = self.map_param_data(w13_scale)
|
||
if len(w2_scale) > 0:
|
||
self.w2_scale = self.map_param_data(w2_scale)
|
||
|
||
# pack activate scale
|
||
if len(a13_scale) > 0:
|
||
self.a13_scale = self.map_param_data(a13_scale)
|
||
if len(a2_scale) > 0:
|
||
self.a2_scale = self.map_param_data(a2_scale)
|
||
|
||
self.pack_params_done = True
|
||
|
||
def pack_params_after_loading(self):
|
||
if self.pack_params_after_loading_done or self.is_fp8_block_wise:
|
||
return
|
||
|
||
if isinstance(self.quant_config, SmoothQuantConfig) and self.quant_config.group_size > 1 and self.is_use_fused_moe:
|
||
assert self.w13_scale is not None and self.w2_scale is not None, "w13_scale and w2_scale must be not None"
|
||
self.w13_scale = self.w13_scale.permute(2, 0, 1).contiguous()
|
||
self.w2_scale = self.w2_scale.permute(2, 0, 1).contiguous()
|
||
|
||
        # Pack the smooth parameters needed by moe_quantize when using FP8.
        # FIXME: pass None for smooth once tmo supports it.
if isinstance(self.quant_config, Fp8Config):
|
||
expert_size = self.w13.shape[0]
|
||
fp8_smooth_2_hidden_size = self.w13.shape[1] // 2 if self.is_gated else self.w13.shape[1]
|
||
self.fp8_smooth_1 = torch.ones([expert_size, self.hidden_size], device=self.w13.device, dtype=torch.float32)
|
||
self.fp8_smooth_2 = torch.ones([expert_size, fp8_smooth_2_hidden_size], device=self.w13.device, dtype=torch.float32)
|
||
|
||
self.pack_params_done = True
|
||
self.pack_params_after_loading_done = True
|
||
|
||
def get_precompute_dim_bytes_list(self, hidden_states_dtype: torch.dtype) -> List[int]:
|
||
'''
|
||
get the number of bytes of the hidden dimension corresponding to
|
||
hidden_states, reduce_weight, and expert_id, respectively.
|
||
'''
|
||
if not self.precompute_dim_bytes_list:
|
||
hidden_states_size = self.hidden_size * get_dtype_size(hidden_states_dtype)
|
||
reduce_weights_size = self.top_k * get_dtype_size(torch.float)
|
||
expert_id_size = self.top_k * get_dtype_size(torch.int32)
|
||
self.precompute_dim_bytes_list = [
|
||
hidden_states_size, reduce_weights_size, expert_id_size
|
||
]
|
||
return self.precompute_dim_bytes_list
|
||
|
||
def get_precompute_dim_bytes(self, hidden_states_dtype: torch.dtype) -> int:
|
||
'''
|
||
get the hidden dimension in bytes for a packed hidden states that
|
||
include
|
||
[hidden_states | reduce_weights | expert_id]
|
||
'''
|
||
if self.precompute_dim_bytes < 0:
|
||
self.precompute_dim_bytes = sum(self.get_precompute_dim_bytes_list(hidden_states_dtype))
|
||
return self.precompute_dim_bytes
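    # Worked example for the packed layout (hypothetical sizes): with
    # hidden_size=7168, bf16 activations and top_k=8, one packed row is
    # 7168 * 2 (hidden_states) + 8 * 4 (fp32 weights) + 8 * 4 (int32 ids)
    # = 14336 + 32 + 32 = 14400 bytes.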
|
||
|
||
def reduce_results(self, final_hidden_states: torch.Tensor, reduce_results: bool = True):
|
||
if reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||
# Default set to False. (May have to add shared expert outputs.)
|
||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||
final_hidden_states, self.moe_group)
|
||
return final_hidden_states
|
||
|
||
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor | None = None) -> torch.Tensor:
|
||
orig_hidden_states_shape = hidden_states.shape
|
||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||
# expert_logits: [num_tokens, self.num_experts_per_rank]
|
||
expert_logits, _ = self.gate(hidden_states)
|
||
final_hidden_states = self.forward_experts(hidden_states, expert_logits, residual)
|
||
final_hidden_states = self.reduce_results(final_hidden_states)
|
||
output = final_hidden_states.view(orig_hidden_states_shape)
|
||
return output
|
||
|
||
def precompute_weight_expert_id(
|
||
self,
|
||
packed: torch.Tensor,
|
||
) -> torch.Tensor:
|
||
        '''
        Precompute the gate and softmax_topk/sigmoid_topk, filling the weight
        and expert_id segments of the packed tensor as shown below:

        in  = [ hidden_states | ------ | --------- ]
              [ bf16          | fp32   | int32     ]

        out = [ hidden_states | weight | expert_id ]
              [ bf16          | fp32   | int32     ]
        '''
hidden_states_size, weight_size, expert_id_size = self.get_precompute_dim_bytes_list(packed.dtype)
|
||
packed_int8 = packed.view(torch.int8)
|
||
hidden_states = packed_int8[:, : hidden_states_size].view(packed.dtype)
|
||
router_logits, _ = self.gate(hidden_states)
|
||
topk=self.top_k
|
||
renormalized=self.renormalize
|
||
reduce_weight = packed_int8[:, hidden_states_size : hidden_states_size + weight_size].view(torch.float)
|
||
expert_id = packed_int8[:, hidden_states_size + weight_size :].view(torch.int32)
|
||
if self.scoring_func == "softmax":
|
||
reduce_weight, expert_id = mlu_ops.moe_softmax_topk(router_logits, topk, renormalized, self.expert_group,
|
||
self.topk_group, route_scale=self.routed_scaling_factor,
|
||
reduce_weight=reduce_weight,
|
||
expert_id=expert_id)
|
||
elif self.scoring_func == "sigmoid":
|
||
reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(router_logits, topk, renormalized,
|
||
self.expert_group, self.topk_group,
|
||
self.routed_scaling_factor,
|
||
self.gate.e_score_correction_bias,
|
||
reduce_weight=reduce_weight,
|
||
expert_id=expert_id)
|
||
else:
|
||
raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
|
||
return packed
|
||
|
||
def forward_experts(self, hidden_states, expert_logits, residual: torch.Tensor | None = None,
|
||
shared_output: torch.Tensor | None = None,
|
||
input_ids: torch.Tensor | None = None):
|
||
assert not (residual is not None and shared_output is not None)
|
||
residual_ = None if self.tp_rank > 0 else residual
|
||
|
||
# change only for deepseek_model without residual_
|
||
if shared_output is not None:
|
||
residual_ = shared_output
|
||
|
||
if self.is_fp8_block_wise:
|
||
output = self.experts(hidden_states=hidden_states,
|
||
router_logits=expert_logits) * self.routed_scaling_factor
|
||
if residual_ is not None:
|
||
output = output + residual_
|
||
return output
|
||
|
||
use_forward_group_experts = (self.is_use_fused_moe
|
||
and (
|
||
self.is_kimi_k2
|
||
or self.is_glm4_moe
|
||
or self.is_deepseek_v4
|
||
or self.expert_group != 1)
|
||
)
|
||
if use_forward_group_experts:
|
||
final_hidden_states = self.forward_group_experts(
|
||
hidden_states,
|
||
expert_logits,
|
||
residual_,
|
||
input_ids=input_ids,
|
||
)
|
||
elif self.is_use_fused_moe:
|
||
self.pack_params()
|
||
self.pack_params_after_loading()
|
||
final_hidden_states = mlu_ops.fused_moe(hidden_states=hidden_states,
|
||
gating_output=expert_logits,
|
||
w1=self.w13,
|
||
w2=self.w2,
|
||
bias1=self.b13,
|
||
bias2=self.b2,
|
||
residual=residual_,
|
||
input_smooth=self.a13_scale,
|
||
act_smooth=self.a2_scale,
|
||
w1_scale=self.w13_scale,
|
||
w2_scale=self.w2_scale,
|
||
topk=self.top_k,
|
||
renormalize=self.renormalize,
|
||
gated=self.is_gated,
|
||
act_mode=self.hidden_act,
|
||
start_expert_id=self.start_expert_id,
|
||
avg_moe=VLLM_AVG_MOE_EN,
|
||
class_reduce_weight=SparseMoeMlp.reduce_weight,
|
||
class_expert_id=SparseMoeMlp.expert_id,
|
||
)
|
||
else:
|
||
final_hidden_states = self.forward_experts_nofused(hidden_states, expert_logits)
|
||
if residual_ is not None:
|
||
final_hidden_states = final_hidden_states + residual_
|
||
return final_hidden_states
|
||
|
||
|
||
def forward_experts_nofused(self, hidden_states, expert_logits):
|
||
hidden_states_shape = hidden_states.shape
|
||
if self.scoring_func == "softmax":
|
||
topk_values, topk_indices = self.topk_softmax(expert_logits)
|
||
elif self.scoring_func == "sigmoid":
|
||
gating_output = expert_logits.to(torch.float32)
|
||
gating_output = gating_output.view(-1, gating_output.size(-1))
|
||
topk_values, topk_indices = grouped_topk(hidden_states, gating_output, self.top_k, self.renormalize,
|
||
self.expert_group, self.topk_group, self.scoring_func,
|
||
self.routed_scaling_factor, self.gate.e_score_correction_bias)
|
||
topk_values = topk_values.to(hidden_states.dtype)
|
||
topk_indices = topk_indices.to(torch.int64)
|
||
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
|
||
topk_indices)
|
||
        # If no token is routed to the experts on this rank, expand_gather_idx and
        # expand_scatter_idx are empty, while expand_token_count and
        # expand_cusum_token_count exist but are all zero, so this rank simply
        # returns an all-zero final_hidden_states.
if expand_gather_idx.numel() == 0:
|
||
final_hidden_states = torch.zeros_like(hidden_states,
|
||
dtype=hidden_states.dtype,
|
||
device=hidden_states.device)
|
||
return final_hidden_states
|
||
|
||
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
|
||
|
||
expand_output_list = []
|
||
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
|
||
1] - cusum_token_count[self.start_expert_id]
|
||
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
|
||
if num_tokens_per_expert > 0:
|
||
expert_hidden_states = expand_hidden_states[
|
||
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
|
||
expert_output = self.experts[expert_idx](expert_hidden_states)
|
||
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
|
||
expand_output_list.append(expert_output)
|
||
expand_output = torch.cat(expand_output_list, dim=0)
|
||
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
|
||
topk_values)
|
||
|
||
return final_hidden_states
|
||
|
||
def forward_group_experts(self, hidden_states, gating_output, residual_, input_ids: torch.Tensor | None = None):
|
||
# determine if hidden_states packs reduce_weight and expert_id in it,
|
||
# and if so, extract them.
|
||
orig_dtype = hidden_states.dtype
|
||
device = hidden_states.device
|
||
hidden_states_int8 = hidden_states.view(torch.int8)
|
||
hidden_states_size, weight_size, _ = self.get_precompute_dim_bytes_list(orig_dtype)
|
||
packed_dim = self.get_precompute_dim_bytes(orig_dtype)
|
||
is_precompute_weight_expert_id: bool = (hidden_states_int8.shape[1] == packed_dim)
|
||
if is_precompute_weight_expert_id:
|
||
assert gating_output is None
|
||
hidden_states = hidden_states_int8[:, : hidden_states_size].view(orig_dtype)
|
||
reduce_weight = hidden_states_int8[:, hidden_states_size : hidden_states_size + weight_size].view(torch.float)
|
||
expert_id = hidden_states_int8[:, hidden_states_size + weight_size :].view(torch.int32)
|
||
|
||
is_fp8_quant = isinstance(self.quant_config, Fp8Config)
|
||
ori_input_shape = hidden_states.shape
|
||
dtype = hidden_states.dtype
|
||
self.pack_params()
|
||
self.pack_params_after_loading()
|
||
w1=self.w13.to(device) if self.w13 is not None else None
|
||
w2=self.w2.to(device) if self.w2 is not None else None
|
||
bias1=self.b13.to(device) if self.b13 is not None else None
|
||
bias2=self.b2.to(device) if self.b2 is not None else None
|
||
input_smooth=self.a13_scale.to(device) if self.a13_scale is not None else None
|
||
act_smooth=self.a2_scale.to(device) if self.a2_scale is not None else None
|
||
w1_scale=self.w13_scale.to(device) if self.w13_scale is not None else None
|
||
w2_scale=self.w2_scale.to(device) if self.w2_scale is not None else None
|
||
topk=self.top_k
|
||
renormalized=self.renormalize
|
||
gated=self.is_gated
|
||
act_mode=self.hidden_act
|
||
quant_input=None
|
||
|
||
start_expert_id=self.start_expert_id
|
||
expert_size = w1.size(0)
|
||
max_m = hidden_states.shape[0]
|
||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||
residual_ = residual_.view(-1, residual_.size(-1)) if residual_ is not None else None
|
||
# Check smooth quant parameters.
|
||
per_token_sq = False
|
||
if not is_fp8_quant:
|
||
check_list = [input_smooth, act_smooth, w1_scale, w2_scale]
|
||
if all(x is not None for x in check_list):
|
||
per_token_sq = True
|
||
|
||
if not (all(x is None for x in check_list) or all(x is not None for x in check_list)):
|
||
raise ValueError("input_smooth, act_smooth, w1_scale and w2_scale must be present "
|
||
"and absent at the same time.")
|
||
|
||
# softmax_topk
|
||
if not is_precompute_weight_expert_id:
|
||
gating_output = gating_output.view(-1, gating_output.size(-1))
|
||
if self.scoring_func == "softmax":
|
||
reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output, topk, renormalized, self.expert_group,
|
||
self.topk_group, route_scale=self.routed_scaling_factor)
|
||
elif self.scoring_func == "sigmoid":
|
||
reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(gating_output, topk, renormalized,
|
||
self.expert_group, self.topk_group,
|
||
self.routed_scaling_factor,
|
||
self.gate.e_score_correction_bias)
|
||
elif self.scoring_func == "sqrtsoftplus":
|
||
assert hasattr(self,"deepseekv4_topk")
|
||
reduce_weight, expert_id = self.deepseekv4_topk(
|
||
gating_output,
|
||
input_ids,
|
||
)
|
||
else:
|
||
raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
|
||
|
||
if VLLM_RANDOM_MOE_EN:
|
||
n_tokens = hidden_states.shape[0]
|
||
token_len = SparseMoeMlp.expert_id.size(0)
|
||
SparseMoeMlp.random_idx = 0 if token_len == n_tokens else (SparseMoeMlp.random_idx+1) % (token_len-n_tokens)
|
||
n_tokens = hidden_states.shape[0]
|
||
reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens]
|
||
expert_id = SparseMoeMlp.expert_id[SparseMoeMlp.random_idx: SparseMoeMlp.random_idx + n_tokens]
|
||
elif VLLM_AVG_MOE_EN:
|
||
n_tokens = hidden_states.shape[0]
|
||
reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens]
|
||
expert_id = SparseMoeMlp.expert_id[:n_tokens]
|
||
# gen_idx
|
||
expand_idx, combine_idx, token_count, cusum_token_count = mlu_ops.moe_gen_idx(expert_id, self.num_total_experts)
|
||
# check quant
|
||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||
hidden_states,
|
||
self.fp8_smooth_1,
|
||
zero=None,
|
||
token_count=token_count[start_expert_id:start_expert_id+expert_size],
|
||
gather_index=expand_idx,
|
||
gather_index_start_position=cusum_token_count[start_expert_id].unsqueeze(0),
|
||
output=None,
|
||
output_scale=None,
|
||
dynamic_quant=True,
|
||
quant_type=torch.float8_e4m3fn
|
||
)
|
||
elif per_token_sq:
|
||
quant_input, input_scale = mlu_ops.moe_quantize(hidden_states,
|
||
input_smooth, None, token_count[start_expert_id:start_expert_id+expert_size], expand_idx,
|
||
cusum_token_count[start_expert_id].unsqueeze(0))
|
||
else:
|
||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||
hidden_states,
|
||
expand_idx,
|
||
cusum_token_count,
|
||
start_expert_id,
|
||
expert_size,
|
||
)
|
||
|
||
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
|
||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w1,
|
||
token_count[start_expert_id:start_expert_id+expert_size],
|
||
None, None, None, None,
|
||
input_scale, w1_scale, dtype, max_m)
|
||
else:
|
||
gemm_out = mlu_ops.group_gemm(expand_hidden_states, w1,
|
||
token_count[start_expert_id:start_expert_id+expert_size],
|
||
None, None, None, None, max_m)
|
||
# add_bias_active
|
||
if is_fp8_quant and self.quant_config.activation_quant_method == 'per_token':
|
||
act_out = mlu_ops.moe_active(gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2], bias=bias1, cusum_token_count=cusum_token_count, start_expert_id=start_expert_id, expert_size=expert_size)
|
||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||
act_out,
|
||
self.fp8_smooth_2,
|
||
zero=None,
|
||
token_count=token_count[start_expert_id:start_expert_id+expert_size],
|
||
gather_index=None,
|
||
gather_index_start_position=None,
|
||
output=quant_input[:,:act_out.shape[-1]],
|
||
output_scale=None,
|
||
dynamic_quant=True,
|
||
quant_type=torch.float8_e4m3fn
|
||
)
|
||
elif per_token_sq:
|
||
quant_input = quant_input[:, :gemm_out.shape[-1] // 2]
|
||
input_scale = input_scale[:gemm_out.shape[0]]
|
||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||
token_count[start_expert_id:start_expert_id+expert_size],
|
||
output=quant_input,
|
||
output_scale=input_scale,
|
||
act_mode=act_mode,
|
||
is_gated=self.is_gated)
|
||
|
||
if (is_fp8_quant and self.quant_config.activation_quant_method == 'per_token') or per_token_sq:
|
||
            # Drop the reference to the gemm_out tensor. If it was the only
            # reference, its memory becomes eligible for deallocation and can be
            # reused by the allocation of the next GEMM.
del gemm_out
|
||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
|
||
token_count[start_expert_id:start_expert_id+expert_size],
|
||
None, None, None, None, input_scale, w2_scale, dtype, max_m)
|
||
else:
|
||
act_out = mlu_ops.moe_active(gemm_out, act_mode, gated, gemm_out[:,:gemm_out.shape[-1]//2], bias1, cusum_token_count, start_expert_id, expert_size)
|
||
gemm_out = mlu_ops.group_gemm(act_out, w2,
|
||
token_count[start_expert_id:start_expert_id+expert_size],
|
||
None, None, None, None, max_m)
|
||
|
||
# we reuse the memory of hidden_states to store the output
|
||
output = mlu_ops.moe_combine_result(
|
||
gemm_out, reduce_weight, combine_idx,
|
||
residual_, cusum_token_count, start_expert_id,
|
||
expert_size, bias2,
|
||
output=hidden_states if not is_precompute_weight_expert_id else None)
|
||
return output.view(ori_input_shape)
|
||
|
||
|
||
def topk_softmax(self, expert_logits):
|
||
# expert_logits: [num_tokens, self.num_experts_per_rank]
|
||
# topk_values: [num_tokens, self.top_k]
|
||
# topk_indices: [num_tokens, self.top_k]
|
||
if self.renormalize:
|
||
topk_values, topk_indices = torch.topk(expert_logits, self.top_k, dim=-1)
|
||
topk_values = torch.softmax(topk_values, -1)
|
||
else:
|
||
router_probs = torch.softmax(expert_logits, -1)
|
||
topk_values, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
||
|
||
return topk_values, topk_indices
|
||
|
||
|
||
    def generate_gather_idx(self, topk_indices):
        device = topk_indices.device
        # gather_idx / scatter_idx: [num_tokens * self.top_k]
        sorted_expert_id, indices = topk_indices.flatten().sort()
        gather_idx = indices // self.top_k

        seqs = torch.arange(indices.numel(), dtype=indices.dtype, device=indices.device)
        scatter_idx = torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)

        # token_count: [self.num_total_experts]
        partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
        zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
        token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
        # cusum_token_count: [self.num_total_experts + 1]
        cusum_token_count = torch.cat(
            [torch.tensor([0], dtype=token_count.dtype, device=device),
             token_count.cumsum(dim=0)])

        num_tokens_before_expert = cusum_token_count[self.start_expert_id]
        num_tokens_including_expert = cusum_token_count[self.end_expert_id]

        expand_gather_idx = gather_idx[num_tokens_before_expert:num_tokens_including_expert]
        expand_token_count = token_count[self.start_expert_id:self.end_expert_id]

        return expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count
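    # Tiny illustration of generate_gather_idx (hypothetical values): with
    # top_k=2 and topk_indices=[[1, 0], [0, 2]], the flattened ids [1, 0, 0, 2]
    # sort to [0, 0, 1, 2] taken from slots [1, 2, 0, 3], so
    # gather_idx = [0, 1, 0, 1] (source token of each sorted slot),
    # scatter_idx = [2, 0, 1, 3] (sorted position of each original slot),
    # token_count = [2, 1, 1] and cusum_token_count = [0, 2, 3, 4].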
|
||
|
||
|
||
def expand_input(self, hidden_states, expand_gather_idx):
|
||
expand_hidden_states = hidden_states[expand_gather_idx]
|
||
return expand_hidden_states
|
||
|
||
|
||
def combine_moe(self, expand_output, scatter_idx, cusum_token_count, hidden_states_shape, topk_values):
|
||
num_tokens, hidden_size = hidden_states_shape
|
||
num_tokens_before_expert = cusum_token_count[self.start_expert_id]
|
||
num_tokens_after_expert = cusum_token_count[-1] - cusum_token_count[self.end_expert_id]
|
||
|
||
expand_output_before_expert = torch.zeros((num_tokens_before_expert, hidden_size),
|
||
dtype=expand_output.dtype,
|
||
device=expand_output.device)
|
||
expand_output_after_expert = torch.zeros((num_tokens_after_expert, hidden_size),
|
||
dtype=expand_output.dtype,
|
||
device=expand_output.device)
|
||
        unscattered_output = torch.cat([expand_output_before_expert, expand_output, expand_output_after_expert], dim=0)
        scatter_output = unscattered_output[scatter_idx]
|
||
hidden_states_weight = topk_values.flatten().unsqueeze(-1)
|
||
weighted_hidden_states = scatter_output * hidden_states_weight
|
||
unreduced_hidden_states = weighted_hidden_states.view(num_tokens, self.top_k, hidden_size)
|
||
final_hidden_states = unreduced_hidden_states.sum(dim=1)
|
||
|
||
return final_hidden_states
|
||
|
||
def prepare_for_cnclep(self, cnclep: CnclEP) -> None:
|
||
if cnclep.use_quant_dispatch:
|
||
self.prepare_for_cnclep_quant_dispatch(cnclep)
|
||
else:
|
||
self.prepare_for_cnclep_bf16(cnclep)
|
||
|
||
def prepare_for_cnclep_bf16(self, cnclep: CnclEP) -> None:
|
||
# prepare buffers for the forward process
|
||
buffer = cnclep.buffer
|
||
self.dispatch_send_buffer = buffer.dispatch_send_token_tensor
|
||
self.dispatch_recv_buffer = buffer.dispatch_recv_token_tensor
|
||
self.combine_send_buffer = buffer.combine_send_token_tensor
|
||
self.combine_recv_buffer = buffer.combine_recv_token_tensor
|
||
self.max_num_tokens_per_rank = cnclep.max_num_tokens_per_rank
|
||
|
||
# get sizes in bytes
|
||
        self.dispatch_token_size = self.hidden_size * 2
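        # e.g. (hypothetical) hidden_size=7168 with bf16 tokens gives
        # 14336 bytes per dispatched token.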
|
||
# [nranks, 2]
|
||
self.dispatch_recv_layout = torch.empty((self.moe_ep_size, 2), dtype=torch.int32, device="mlu")
|
||
# [num_total_experts]
|
||
self.dispatch_recv_token_num = torch.empty((self.num_total_experts), dtype=torch.int32, device="mlu")
|
||
|
||
self.max_num_tokens_recv = self.max_num_tokens_per_rank * self.moe_ep_size
|
||
self.max_num_tokens_per_expert = divide(self.max_num_tokens_recv, self.top_k)
|
||
|
||
# input to the first groupgemm, in which tokens are ordered by experts.
|
||
input_recv_size = self.max_num_tokens_recv * self.dispatch_token_size
|
||
self.input_recv = (
|
||
self.combine_send_buffer[:input_recv_size]
|
||
.view(self.max_num_tokens_recv, self.dispatch_token_size)
|
||
)
|
||
# kept for code without compute-communication parallel, which may have
|
||
# become stale.
|
||
self.quant_input_recv = self.input_recv
|
||
|
||
def prepare_for_cnclep_quant_dispatch(self, cnclep: CnclEP) -> None:
|
||
# prepare smooth parameter for _all_ experts globally, which would be needed during
|
||
# input quantization before dispatch.
|
||
assert self.a13_scale is not None, "a13_scale has not been loaded"
|
||
self.a13_scale_all_experts = torch.zeros((self.num_total_experts, self.hidden_size),
|
||
dtype=self.a13_scale.dtype,
|
||
device=self.a13_scale.device)
|
||
torch.distributed.all_gather_into_tensor(self.a13_scale_all_experts,
|
||
self.a13_scale,
|
||
group=self.moe_group.device_group,
|
||
async_op=False)
|
||
|
||
# prepare buffers for the forward process
|
||
buffer = cnclep.buffer
|
||
self.dispatch_send_buffer = buffer.dispatch_send_token_tensor
|
||
self.dispatch_recv_buffer = buffer.dispatch_recv_token_tensor
|
||
self.combine_send_buffer = buffer.combine_send_token_tensor
|
||
self.combine_recv_buffer = buffer.combine_recv_token_tensor
|
||
self.max_num_tokens_per_rank = cnclep.max_num_tokens_per_rank
|
||
|
||
# get sizes in bytes
|
||
self.quant_size = self.hidden_size
|
||
self.scale_size = get_dtype_size(torch.float32)
|
||
self.dispatch_token_size = self.quant_size + self.scale_size
|
||
# [nranks, 2]
|
||
self.dispatch_recv_layout = torch.empty((self.moe_ep_size, 2), dtype=torch.int32, device="mlu")
|
||
# [num_total_experts]
|
||
self.dispatch_recv_token_num = torch.empty((self.num_total_experts), dtype=torch.int32, device="mlu")
|
||
|
||
self.max_num_tokens_recv = self.max_num_tokens_per_rank * self.moe_ep_size
|
||
self.max_num_tokens_per_expert = divide(self.max_num_tokens_recv, self.top_k)
|
||
|
||
quant_input_recv_size = self.max_num_tokens_recv * self.quant_size
|
||
input_scale_recv_size = self.max_num_tokens_recv * self.scale_size
|
||
self.quant_input_recv = (
|
||
self.combine_send_buffer[:quant_input_recv_size]
|
||
.view(self.max_num_tokens_recv, self.quant_size))
|
||
self.input_scale_recv = (
|
||
self.combine_send_buffer[quant_input_recv_size : quant_input_recv_size + input_scale_recv_size]
|
||
.view(self.max_num_tokens_recv, self.scale_size))
|
||
|
||
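    # Sketch of the all-to-all EP flow implemented in forward_all2all below:
    # route tokens with the gate, quantize them against the all-gathered
    # per-expert smooth scales, send them to their owning EP ranks via
    # cnclep_dispatch, run the two smooth-quant group GEMMs locally, return the
    # results to the source ranks via cnclep_combine, and finally weight and
    # accumulate them in moe_combine_result. Shared experts, when given, may run
    # on a separate stream in parallel with the combine communication.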
def forward_all2all(
|
||
self,
|
||
hidden_states: torch.Tensor,
|
||
gate: ReplicatedLinear,
|
||
streams: Optional[Dict[str, torch.mlu.Stream]] = None,
|
||
shared_experts: Optional[nn.Module] = None,
|
||
) -> torch.Tensor:
|
||
"""forward with all2all."""
|
||
ori_input_shape = hidden_states.shape
|
||
dtype = hidden_states.dtype
|
||
self.pack_params()
|
||
self.pack_params_after_loading()
|
||
w1=self.w13
|
||
w2=self.w2
|
||
bias2=self.b2
|
||
input_smooth=self.a13_scale_all_experts
|
||
act_smooth=self.a2_scale
|
||
w1_scale=self.w13_scale
|
||
w2_scale=self.w2_scale
|
||
topk=self.top_k
|
||
renormalized=self.renormalize
|
||
act_mode=self.hidden_act
|
||
quant_input=None
|
||
|
||
start_expert_id=self.start_expert_id
|
||
expert_size = w1.size(0)
|
||
max_m = hidden_states.shape[0]
|
||
gating_output, _ = gate(hidden_states)
|
||
gating_output = gating_output.view(-1, gating_output.size(-1))
|
||
if self.scoring_func == "softmax":
|
||
reduce_weight, expert_id = mlu_ops.moe_softmax_topk(gating_output, topk, renormalized, self.expert_group,
|
||
self.topk_group, route_scale=self.routed_scaling_factor)
|
||
elif self.scoring_func == "sigmoid":
|
||
reduce_weight, expert_id = mlu_ops.moe_sigmoid_topk(gating_output, topk, renormalized,
|
||
self.expert_group, self.topk_group,
|
||
self.routed_scaling_factor,
|
||
self.gate.e_score_correction_bias)
|
||
else:
|
||
raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
|
||
|
||
if VLLM_AVG_MOE_EN:
|
||
# get dp rank
|
||
dp_rank = get_dp_group().rank_in_group
|
||
tp_rank = get_tp_group().rank_in_group
|
||
global_rank = dp_rank * get_tp_group().world_size + tp_rank
|
||
n_tokens = hidden_states.shape[0]
|
||
reduce_weight = SparseMoeMlp.reduce_weight[:n_tokens]
|
||
if self.use_all2all and VLLM_RANDOM_MOE_EN:
|
||
expert_id = SparseMoeMlp.expert_id[global_rank * n_tokens : (global_rank+1) * n_tokens]
|
||
elif self.use_all2all:
|
||
expert_id = SparseMoeMlp.expert_id[dp_rank * n_tokens: dp_rank * n_tokens + n_tokens]
|
||
else:
|
||
expert_id = SparseMoeMlp.expert_id[:n_tokens]
|
||
|
||
expand_idx, combine_idx, token_count, cusum_token_count \
|
||
= mlu_ops.moe_gen_idx(expert_id, self.num_total_experts)
|
||
|
||
num_token_expand = hidden_states.shape[0] * self.top_k
|
||
dispatch_bytes = num_token_expand * self.dispatch_token_size
|
||
|
||
dispatch_send_token_tensor = (
|
||
self.dispatch_send_buffer[:dispatch_bytes]
|
||
.view(num_token_expand, self.dispatch_token_size)
|
||
)
|
||
|
||
quant_size = self.hidden_size
|
||
quant_input = dispatch_send_token_tensor[:, : quant_size]
|
||
input_scale = dispatch_send_token_tensor[:, quant_size :].view(torch.float32)
|
||
quant_input, input_scale = mlu_ops.moe_quantize(
|
||
hidden_states, input_smooth, None, token_count, expand_idx, None,
|
||
output=quant_input,
|
||
output_scale=input_scale)
|
||
|
||
dispatch_send_layout = mlu_ops.moe_all2all_gen_send_layout(token_count, self.moe_ep_size)
|
||
|
||
cnclep_dispatch(self.dispatch_token_size,
|
||
num_token_expand,
|
||
dispatch_send_layout,
|
||
token_count,
|
||
self.dispatch_recv_layout,
|
||
self.dispatch_recv_token_num)
|
||
|
||
recv_token_num = self.dispatch_recv_token_num.view(self.moe_ep_size, self.num_experts_per_rank)
|
||
pad_num = self.max_num_tokens_per_rank
|
||
|
||
(
|
||
gather_by_expert_index,
|
||
gather_by_rank_index,
|
||
tokens_per_local_expert,
|
||
token_sum
|
||
) = mlu_ops.moe_all2all_gen_gather_index(recv_token_num, pad_num)
|
||
|
||
max_tokens_bytes_recv = self.max_num_tokens_recv * self.dispatch_token_size
|
||
dispatch_recv_token_tensor = (
|
||
self.dispatch_recv_buffer[:max_tokens_bytes_recv]
|
||
.view(self.max_num_tokens_recv, self.dispatch_token_size))
|
||
|
||
mlu_ops.gather_split(dispatch_recv_token_tensor,
|
||
gather_by_expert_index,
|
||
token_sum,
|
||
self.quant_input_recv,
|
||
self.input_scale_recv)
|
||
|
||
max_m = self.max_num_tokens_per_expert
|
||
gemm_out = mlu_ops.smooth_quant_group_gemm(self.quant_input_recv, w1,
|
||
tokens_per_local_expert,
|
||
None, None, None, None,
|
||
self.input_scale_recv.view(torch.float32).flatten(),
|
||
w1_scale, dtype, max_m)
|
||
|
||
# continue reusing self.quant_input_recv and self.input_scale_recv
|
||
quant_input = self.quant_input_recv[:, :gemm_out.shape[-1] // 2]
|
||
input_scale_fp32 = self.input_scale_recv.view(torch.float32).flatten()[:gemm_out.shape[0]]
|
||
quant_input, input_scale = mlu_ops.moe_quantize(gemm_out, act_smooth, None,
|
||
tokens_per_local_expert,
|
||
output=quant_input,
|
||
output_scale=input_scale_fp32,
|
||
act_mode=act_mode,
|
||
is_gated=self.is_gated)
|
||
|
||
gemm_out = mlu_ops.smooth_quant_group_gemm(quant_input, w2,
|
||
tokens_per_local_expert,
|
||
None, None, None, None, input_scale, w2_scale, dtype, max_m)
|
||
|
||
combine_send_token_tensor = self.combine_send_buffer.view(self.max_num_tokens_recv, -1).view(hidden_states.dtype)
|
||
mlu_ops.gather_split(gemm_out,
|
||
gather_by_rank_index,
|
||
token_sum,
|
||
combine_send_token_tensor,
|
||
None)
|
||
|
||
combine_send_layout = mlu_ops.moe_all2all_gen_send_layout(self.dispatch_recv_token_num, self.moe_ep_size)
|
||
combine_recv_layout = self.dispatch_recv_layout
|
||
|
||
# combine
|
||
combine_args = dict(
|
||
token_byte=self.hidden_size * 2,
|
||
token_num=num_token_expand,
|
||
send_src_layout=combine_send_layout,
|
||
send_dst_layout=combine_recv_layout,
|
||
send_token=None,
|
||
recv_token=None)
|
||
|
||
shared_output = None
|
||
if shared_experts is not None:
|
||
parallelize_shared_expert = streams is not None
|
||
if parallelize_shared_expert:
|
||
compute_stream = streams['shared']
|
||
comm_stream = streams['routed']
|
||
curr_stream = torch.mlu.current_stream()
|
||
compute_stream.wait_stream(curr_stream)
|
||
comm_stream.wait_stream(curr_stream)
|
||
|
||
with torch.mlu.stream(compute_stream):
|
||
shared_output = shared_experts(hidden_states, use_tp_weight=False)
|
||
|
||
with torch.mlu.stream(comm_stream):
|
||
cnclep_combine(**combine_args)
|
||
|
||
curr_stream.wait_stream(compute_stream)
|
||
curr_stream.wait_stream(comm_stream)
|
||
else:
|
||
shared_output = shared_experts(hidden_states, use_tp_weight=False)
|
||
cnclep_combine(**combine_args)
|
||
else:
|
||
cnclep_combine(**combine_args)
|
||
|
||
numel_recv = num_token_expand * self.hidden_size
|
||
recv_token = (self.combine_recv_buffer.view(hidden_states.dtype)[:numel_recv]
|
||
.view(num_token_expand, self.hidden_size))
|
||
|
||
residual_ = shared_output
|
||
output = mlu_ops.moe_combine_result(recv_token, reduce_weight, combine_idx,
|
||
residual_, None, start_expert_id,
|
||
expert_size, bias2, output=hidden_states)
|
||
|
||
return output.view(ori_input_shape)
|