[Feature] Support XiaoMi MIMO Flash V2 (#62)

* [Feature] Support MIMO Flash V2
Author: Xinyu Dong
Date: 2025-12-31 10:16:33 +08:00
Committed by: GitHub
Parent: 341dc7f296
Commit: b3c30a3cb9
12 changed files with 1530 additions and 690 deletions

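Throughout this diff each new kernel is registered three times: a default custom_op body, a CUDA impl, and a register_fake stub so the op remains traceable under torch.compile. Below is a minimal sketch of that pattern, assuming the custom_op decorator here wraps torch.library (the toy op _demo::scale2x is hypothetical and not part of the commit):

import torch
from torch.library import custom_op

# Real computation: out-parameter style, mirroring the ops in this diff.
@custom_op("_demo::scale2x", mutates_args=("y",))
def scale2x(x: torch.Tensor, y: torch.Tensor) -> None:
    y.copy_(x * 2)

# Fake (meta) implementation: lets torch.compile trace shapes and dtypes
# without launching the real kernel; it intentionally computes nothing.
@scale2x.register_fake
def _fake_scale2x(x: torch.Tensor, y: torch.Tensor) -> None:
    return None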

@@ -938,6 +938,83 @@ def _fake_rotary_embedding(
 rotary_embedding.register_fake(_fake_rotary_embedding)
+
+
+@custom_op("_C::quant2d", mutates_args=())
+def quant2d(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    max: torch.Tensor,
+    force_sdnn: bool,
+) -> None:
+    xtorch_ops.quant2d(
+        x=x,
+        y=y,
+        max=max,
+        force_sdnn=force_sdnn
+    )
+
+
+@impl("_C::quant2d", "CUDA")
+def quant2d_cuda(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    max: torch.Tensor,
+    force_sdnn: bool,
+) -> None:
+    xtorch_ops.quant2d(
+        x=x,
+        y=y,
+        max=max,
+        force_sdnn=force_sdnn
+    )
+
+
+def _fake_quant2d(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    max: torch.Tensor,
+    force_sdnn: bool,
+) -> None:
+    return None
+
+
+quant2d.register_fake(_fake_quant2d)
+
+
+@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
+def gemm_I8_I8_bf16_nt(
+    x_q: torch.Tensor,
+    x_scale: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    out: torch.Tensor,
+) -> None:
+    xtorch_ops.gemm_I8_I8_bf16_nt(
+        lhs=(x_q, x_scale),
+        rhs=(weight, weight_scale),
+        out=out
+    )
+
+
+@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
+def gemm_I8_I8_bf16_nt_cuda(
+    x_q: torch.Tensor,
+    x_scale: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    out: torch.Tensor,
+) -> None:
+    xtorch_ops.gemm_I8_I8_bf16_nt(
+        lhs=(x_q, x_scale),
+        rhs=(weight, weight_scale),
+        out=out
+    )
+
+
+def _fake_gemm_I8_I8_bf16_nt(
+    x_q: torch.Tensor,
+    x_scale: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    out: torch.Tensor,
+) -> None:
+    return None
+
+
+gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)


 @custom_op("_C::moe_softmax_topk_norm", mutates_args=())
 def moe_softmax_topk_norm(
     x: torch.Tensor,
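The two ops added above pair naturally: quant2d fills an int8 buffer and a per-row max for the activations, which gemm_I8_I8_bf16_nt then consumes as its lhs scale. A usage sketch, not taken from the PR; the buffer shapes, dtypes, and reading max as the quantization scale are assumptions inferred from the signatures:

import torch

def int8_linear(x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # Both ops are out-parameter style (they return None), so all outputs
    # are preallocated here. Assumed: x is [m, k] bf16, w_q is [n, k] int8
    # ("nt" = rhs transposed), w_scale is per-output-channel.
    m, k = x.shape
    n = w_q.shape[0]
    x_q = torch.empty(m, k, dtype=torch.int8, device=x.device)
    x_max = torch.empty(m, dtype=torch.float32, device=x.device)  # per-row absmax (assumed)
    torch.ops._C.quant2d(x, x_q, x_max, False)                    # force_sdnn=False
    out = torch.empty(m, n, dtype=torch.bfloat16, device=x.device)
    torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_max, w_q, w_scale, out)
    return out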
@@ -1068,15 +1145,39 @@ def moe_fc(
     sorted_tokens_num_lod: torch.Tensor,
     sorted_tokens_idx: torch.Tensor,
     moe_topk: int,
-    y: torch.Tensor
+    y: torch.Tensor,
+    act: Optional[str] = None,
+    x_perchannel_max: Optional[torch.Tensor] = None,
+    w_perchannel_max: Optional[torch.Tensor] = None,
+    topk_ids: Optional[torch.Tensor] = None,
+    topk_w: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    tgemm_type: Optional[str] = None,
+    tweight_type: Optional[str] = None,
+    scale_n: Optional[int] = 0,
+    scale_k: Optional[int] = 0,
+    use_pack_int4: Optional[bool] = False,
+    sort_mode: Optional[bool] = True
 ) -> None:
     xtorch_ops.moe_fc(
-        x,
-        weight,
-        sorted_tokens_num_lod,
-        sorted_tokens_idx,
-        moe_topk,
-        y)
+        x=x,
+        weight=weight,
+        sorted_tokens_num_lod=sorted_tokens_num_lod,
+        sorted_tokens_idx=sorted_tokens_idx,
+        moe_topk=moe_topk,
+        y=y,
+        act=act,
+        x_perchannel_max=x_perchannel_max,
+        w_perchannel_max=w_perchannel_max,
+        topk_ids=topk_ids,
+        topk_w=topk_w,
+        bias=bias,
+        tgemm_type=tgemm_type,
+        tweight_type=tweight_type,
+        scale_n=scale_n,
+        scale_k=scale_k,
+        use_pack_int4=use_pack_int4,
+        sort_mode=sort_mode)


 @impl("_C::moe_fc", "CUDA")
 def moe_fc_cuda(
@@ -1085,15 +1186,39 @@ def moe_fc_cuda(
     sorted_tokens_num_lod: torch.Tensor,
     sorted_tokens_idx: torch.Tensor,
     moe_topk: int,
-    y: torch.Tensor
+    y: torch.Tensor,
+    act: Optional[str] = None,
+    x_perchannel_max: Optional[torch.Tensor] = None,
+    w_perchannel_max: Optional[torch.Tensor] = None,
+    topk_ids: Optional[torch.Tensor] = None,
+    topk_w: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    tgemm_type: Optional[str] = None,
+    tweight_type: Optional[str] = None,
+    scale_n: Optional[int] = 0,
+    scale_k: Optional[int] = 0,
+    use_pack_int4: Optional[bool] = False,
+    sort_mode: Optional[bool] = True
 ) -> None:
     xtorch_ops.moe_fc(
-        x,
-        weight,
-        sorted_tokens_num_lod,
-        sorted_tokens_idx,
-        moe_topk,
-        y)
+        x=x,
+        weight=weight,
+        sorted_tokens_num_lod=sorted_tokens_num_lod,
+        sorted_tokens_idx=sorted_tokens_idx,
+        moe_topk=moe_topk,
+        y=y,
+        act=act,
+        x_perchannel_max=x_perchannel_max,
+        w_perchannel_max=w_perchannel_max,
+        topk_ids=topk_ids,
+        topk_w=topk_w,
+        bias=bias,
+        tgemm_type=tgemm_type,
+        tweight_type=tweight_type,
+        scale_n=scale_n,
+        scale_k=scale_k,
+        use_pack_int4=use_pack_int4,
+        sort_mode=sort_mode)


 def fake_moe_fc(
     x: torch.Tensor,
@@ -1101,7 +1226,19 @@ def fake_moe_fc(
     sorted_tokens_num_lod: torch.Tensor,
     sorted_tokens_idx: torch.Tensor,
     moe_topk: int,
-    y: torch.Tensor
+    y: torch.Tensor,
+    act: Optional[str] = None,
+    x_perchannel_max: Optional[torch.Tensor] = None,
+    w_perchannel_max: Optional[torch.Tensor] = None,
+    topk_ids: Optional[torch.Tensor] = None,
+    topk_w: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    tgemm_type: Optional[str] = None,
+    tweight_type: Optional[str] = None,
+    scale_n: Optional[int] = 0,
+    scale_k: Optional[int] = 0,
+    use_pack_int4: Optional[bool] = False,
+    sort_mode: Optional[bool] = True
 ) -> None:
     return None
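For reference, a hedged call-site sketch for the widened moe_fc: the original six arguments keep their pre-PR meaning, and the new keywords opt into fused activation and per-channel quantized paths. The string values below ("swiglu", "int8") are illustrative guesses, not values confirmed by this diff:

import torch

def moe_fc_int8(x, weight, lod, idx, topk, y, x_max, w_max):
    # Same call as before the PR, plus the new optional quantization inputs.
    torch.ops._C.moe_fc(
        x, weight, lod, idx, topk, y,
        act="swiglu",            # hypothetical fused-activation tag
        x_perchannel_max=x_max,  # per-channel activation maxima (assumed)
        w_perchannel_max=w_max,  # per-channel weight maxima (assumed)
        tgemm_type="int8",       # hypothetical GEMM-type selector
        use_pack_int4=False,     # weights not packed two-per-byte
        sort_mode=True,          # keep the pre-PR token-sorting behaviour
    )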
@@ -1151,6 +1288,63 @@ def fake_moe_post(
 moe_post.register_fake(fake_moe_post)
+
+
+@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
+def moe_sigmoid_group_topk_norm(
+    x: torch.Tensor,
+    topk_index: torch.Tensor,
+    norm_score: torch.Tensor,
+    block_static: torch.Tensor,
+    bias: torch.Tensor,
+    scale: float,
+    n_group: int,
+    topk_group: int
+) -> None:
+    xtorch_ops.moe_sigmoid_group_topk_norm(
+        x=x,
+        norm_score=norm_score,
+        topk_index=topk_index,
+        block_static=block_static,
+        bias=bias,
+        n_group=n_group,
+        topk_group=topk_group,
+        scale=scale,
+    )
+
+
+@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
+def moe_sigmoid_group_topk_norm_cuda(
+    x: torch.Tensor,
+    topk_index: torch.Tensor,
+    norm_score: torch.Tensor,
+    block_static: torch.Tensor,
+    bias: torch.Tensor,
+    scale: float,
+    n_group: int,
+    topk_group: int
+) -> None:
+    xtorch_ops.moe_sigmoid_group_topk_norm(
+        x=x,
+        norm_score=norm_score,
+        topk_index=topk_index,
+        block_static=block_static,
+        bias=bias,
+        n_group=n_group,
+        topk_group=topk_group,
+        scale=scale,
+    )
+
+
+def _fake_moe_sigmoid_group_topk_norm(
+    x: torch.Tensor,
+    topk_index: torch.Tensor,
+    norm_score: torch.Tensor,
+    block_static: torch.Tensor,
+    bias: torch.Tensor,
+    scale: float,
+    n_group: int,
+    topk_group: int
+) -> None:
+    return None
+
+
+moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)


 ##################################################
 # --------------- awq_dequantize -----------------
 ##################################################
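The argument names of moe_sigmoid_group_topk_norm (bias, scale, n_group, topk_group) suggest grouped sigmoid routing in the style of DeepSeek-V3's node-limited top-k. Below is a plain-PyTorch reading of that scheme, offered purely as an assumption about the fused kernel's semantics: the top-2 group scoring and the final top-k count are guesses, and block_static (likely per-block statistics for a later sorting stage) is not modeled.

import torch

def sigmoid_group_topk_norm_ref(x, bias, scale, n_group, topk_group, topk):
    scores = torch.sigmoid(x)                       # [tokens, experts]
    biased = scores + bias                          # bias steers routing only
    g = biased.view(x.shape[0], n_group, -1)        # experts split into groups
    group_score = g.topk(2, dim=-1).values.sum(-1)  # score a group by its top-2 sum
    keep = group_score.topk(topk_group, dim=-1).indices
    mask = torch.zeros_like(group_score, dtype=torch.bool).scatter_(1, keep, True)
    masked = g.masked_fill(~mask.unsqueeze(-1), float("-inf")).flatten(1)
    _, topk_index = masked.topk(topk, dim=-1)       # experts from kept groups only
    topk_w = scores.gather(1, topk_index)           # weights use unbiased scores
    norm_score = scale * topk_w / topk_w.sum(-1, keepdim=True)
    return topk_index, norm_score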