[Feature] Support XiaoMi MIMO Flash V2 (#62)

* [Feature] Support MIMO Flash V2
This commit is contained in:
Xinyu Dong
2025-12-31 10:16:33 +08:00
committed by GitHub
parent 341dc7f296
commit b3c30a3cb9
12 changed files with 1530 additions and 690 deletions

View File

@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
linear_weights=linear_weights)
linear_weights=linear_weights,
e_score_correction_bias=e_score_correction_bias)
def forward_kunlun(
self,
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
num_expert_group=num_expert_group,
topk_group=topk_group
)
# fused_moe do not support expert number > 400
elif layer.local_num_experts > 400:
hidden_states = x
global_num_experts = linear_weights.shape[0]
M, N = hidden_states.shape
hidden_dim = layer.w2_weight.shape[1]
normed_score = torch.empty(M,
top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
top_k,
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
router_logits = router_logits.float()
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=layer.w13_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.swiglu(y, out1)
out = torch.empty(M,top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=layer.w2_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=out)
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
else:
return ops.fused_moe(x,
layer.w13_weight,
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
class FusedMoE(VllmFusedMoE):