[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -938,6 +938,83 @@ def _fake_rotary_embedding(
|
||||
|
||||
# Register the fake (meta) implementation so torch.compile / tracing can
# reason about the rotary_embedding custom op without running the kernel.
rotary_embedding.register_fake(_fake_rotary_embedding)
@custom_op("_C::quant2d", mutates_args=())
def quant2d(
    x: torch.Tensor,
    y: torch.Tensor,
    max: torch.Tensor,  # shadows builtin `max`; kept — name is part of the op schema
    force_sdnn: bool,
) -> None:
    """Quantize the 2-D tensor ``x`` via the xtorch ``quant2d`` kernel.

    Returns nothing; results presumably land in ``y``/``max`` in place.
    NOTE(review): ``mutates_args=()`` declares no mutation — confirm against
    the kernel's actual contract.
    """
    xtorch_ops.quant2d(x=x, y=y, max=max, force_sdnn=force_sdnn)
@impl("_C::quant2d", "CUDA")
def quant2d_cuda(
    x: torch.Tensor,
    y: torch.Tensor,
    max: torch.Tensor,  # shadows builtin `max`; kept — name is part of the op schema
    force_sdnn: bool,
) -> None:
    """CUDA dispatch for ``_C::quant2d`` — delegates to the xtorch kernel."""
    xtorch_ops.quant2d(x=x, y=y, max=max, force_sdnn=force_sdnn)
def _fake_quant2d(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
# Register the no-op meta kernel for quant2d (needed for tracing/compile).
quant2d.register_fake(_fake_quant2d)
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
def gemm_I8_I8_bf16_nt(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """int8 x int8 -> bf16 GEMM (rhs transposed), dispatched to xtorch.

    Operands are (quantized tensor, scale) pairs. Returns nothing; the
    result presumably fills ``out`` in place. NOTE(review):
    ``mutates_args=()`` declares no mutation — confirm.
    """
    lhs = (x_q, x_scale)
    rhs = (weight, weight_scale)
    xtorch_ops.gemm_I8_I8_bf16_nt(lhs=lhs, rhs=rhs, out=out)
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """CUDA dispatch for ``_C::gemm_I8_I8_bf16_nt`` — delegates to xtorch."""
    lhs = (x_q, x_scale)
    rhs = (weight, weight_scale)
    xtorch_ops.gemm_I8_I8_bf16_nt(lhs=lhs, rhs=rhs, out=out)
def _fake_gemm_I8_I8_bf16_nt(
|
||||
x_q: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
# Register the no-op meta kernel for the int8 GEMM op.
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
|
||||
def moe_softmax_topk_norm(
|
||||
x: torch.Tensor,
|
||||
@@ -1068,15 +1145,39 @@ def moe_fc(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
xtorch_ops.moe_fc(
|
||||
x,
|
||||
weight,
|
||||
sorted_tokens_num_lod,
|
||||
sorted_tokens_idx,
|
||||
moe_topk,
|
||||
y)
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_topk,
|
||||
y=y,
|
||||
act=act,
|
||||
x_perchannel_max=x_perchannel_max,
|
||||
w_perchannel_max=w_perchannel_max,
|
||||
topk_ids=topk_ids,
|
||||
topk_w=topk_w,
|
||||
bias=bias,
|
||||
tgemm_type=tgemm_type,
|
||||
tweight_type=tweight_type,
|
||||
scale_n=scale_n,
|
||||
scale_k=scale_k,
|
||||
use_pack_int4=use_pack_int4,
|
||||
sort_mode=sort_mode)
|
||||
|
||||
@impl("_C::moe_fc", "CUDA")
|
||||
def moe_fc_cuda(
|
||||
@@ -1085,15 +1186,39 @@ def moe_fc_cuda(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
xtorch_ops.moe_fc(
|
||||
x,
|
||||
weight,
|
||||
sorted_tokens_num_lod,
|
||||
sorted_tokens_idx,
|
||||
moe_topk,
|
||||
y)
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_topk,
|
||||
y=y,
|
||||
act=act,
|
||||
x_perchannel_max=x_perchannel_max,
|
||||
w_perchannel_max=w_perchannel_max,
|
||||
topk_ids=topk_ids,
|
||||
topk_w=topk_w,
|
||||
bias=bias,
|
||||
tgemm_type=tgemm_type,
|
||||
tweight_type=tweight_type,
|
||||
scale_n=scale_n,
|
||||
scale_k=scale_k,
|
||||
use_pack_int4=use_pack_int4,
|
||||
sort_mode=sort_mode)
|
||||
|
||||
def fake_moe_fc(
|
||||
x: torch.Tensor,
|
||||
@@ -1101,7 +1226,19 @@ def fake_moe_fc(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
sorted_tokens_idx: torch.Tensor,
|
||||
moe_topk: int,
|
||||
y: torch.Tensor
|
||||
y: torch.Tensor,
|
||||
act: Optional[str] = None,
|
||||
x_perchannel_max: Optional[torch.Tensor] = None,
|
||||
w_perchannel_max: Optional[torch.Tensor] = None ,
|
||||
topk_ids: Optional[torch.Tensor] = None,
|
||||
topk_w: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tgemm_type: Optional[str] = None,
|
||||
tweight_type: Optional[str] = None,
|
||||
scale_n: Optional[int] = 0,
|
||||
scale_k: Optional[int] = 0,
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True
|
||||
)-> None:
|
||||
return None
|
||||
|
||||
@@ -1151,6 +1288,63 @@ def fake_moe_post(
|
||||
# Register the no-op meta kernel for moe_post.
moe_post.register_fake(fake_moe_post)
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int
) -> None:
    """Sigmoid-scored grouped top-k MoE routing with score normalization.

    Delegates to the xtorch kernel. Returns nothing; outputs presumably land
    in ``topk_index`` / ``norm_score`` in place. NOTE(review):
    ``mutates_args=()`` declares no mutation — confirm.
    """
    xtorch_ops.moe_sigmoid_group_topk_norm(
        x=x,
        topk_index=topk_index,
        norm_score=norm_score,
        block_static=block_static,
        bias=bias,
        scale=scale,
        n_group=n_group,
        topk_group=topk_group,
    )
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int
) -> None:
    """CUDA dispatch for ``_C::moe_sigmoid_group_topk_norm`` (xtorch)."""
    xtorch_ops.moe_sigmoid_group_topk_norm(
        x=x,
        topk_index=topk_index,
        norm_score=norm_score,
        block_static=block_static,
        bias=bias,
        scale=scale,
        n_group=n_group,
        topk_group=topk_group,
    )
def _fake_moe_sigmoid_group_topk_norm(
|
||||
x: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
norm_score: torch.Tensor,
|
||||
block_static: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
scale: float,
|
||||
n_group: int,
|
||||
topk_group: int
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
# Register the no-op meta kernel for moe_sigmoid_group_topk_norm.
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
##################################################
|
||||
# --------------- awq_dequantize -----------------
|
||||
##################################################
|
||||
|
||||
Reference in New Issue
Block a user