[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights)
|
||||
linear_weights=linear_weights,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
# fused_moe do not support expert number > 400
|
||||
elif layer.local_num_experts > 400:
|
||||
hidden_states = x
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
router_logits = router_logits.float()
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=layer.w13_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.swiglu(y, out1)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=layer.w2_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
|
||||
Reference in New Issue
Block a user