add 2 kernels and optimize the calculation of topk_indices (#134)

Author: fromck
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
Date: 2026-01-22 10:29:28 +08:00
Committed by: GitHub
Parent: c9f00c132c
Commit: 74d4f804e8
4 changed files with 108 additions and 32 deletions


@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
         out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
         torch.ops._C.silu_and_mul(out1, y)
-        out = torch.empty(
-            M,
-            top_k,
-            layer.w2_weight.shape[1],
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
+        del y
         out1 = out1.reshape(-1, out1.shape[-1])
         x_shape = out1.shape
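
This hunk and the two below apply a peak-memory pattern: each large intermediate is del'd as soon as the op that consumes it has run, and the out buffer is no longer allocated up front but only right before moe_fc fills it. A minimal plain-PyTorch sketch of the pattern, with illustrative shapes and names mirroring the diff (y, out1, out):

import torch
import torch.nn.functional as F

M, N, K = 4, 8, 16                  # illustrative sizes only

y = torch.randn(M, 2 * N)           # fused gate/up projection output
out1 = F.silu(y[:, :N]) * y[:, N:]  # stand-in for torch.ops._C.silu_and_mul
del y                               # y is dead here; dropping the reference lets
                                    # the caching allocator reuse its block now

# ... quantize out1 and del it (next hunk) ...
out = torch.empty(M, K)             # allocate the output only when it is needed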
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
             (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
         )
         torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
+        del out1, moe_expand
+        out = torch.empty(
+            M,
+            top_k,
+            layer.w2_weight.shape[1],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
         torch.ops._C.moe_fc(
             x=x_q,
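
quant2d fills a preallocated int8 tensor x_q plus a per-row float32 x_scale of shape (rows, 1). The kernel itself is Kunlun-specific and not shown in this diff; as a sketch only, the usual contract for such an (x_q, x_scale) pair is symmetric per-row (per-token) int8 quantization, which in plain PyTorch looks like:

import torch

def quant2d_reference(x: torch.Tensor):
    # Assumed semantics, not the actual torch.ops._C.quant2d kernel:
    # symmetric per-row int8 quantization with one float32 scale per row.
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-10)  # avoid division by zero on all-zero rows
    x_q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_q, scale

x_q, x_scale = quant2d_reference(torch.randn(6, 32))
x_deq = x_q.float() * x_scale       # dequantize: x ≈ x_q * x_scale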
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
             # sort_mode=False,
             act=None,
         )
+        del x_q, x_scale, sorted_tokens_num_lod, expert_m
         dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
         output = torch.empty(
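
The commit title also mentions optimizing the calculation of topk_indices, but that kernel is not part of the hunks shown here. For orientation only, the standard MoE routing computation such a kernel accelerates picks, per token, the indices and normalized weights of the top_k experts; a plain torch.topk sketch with illustrative names and shapes:

import torch

router_logits = torch.randn(16, 64)  # (num_tokens, num_experts), illustrative
top_k = 2
topk_weights, topk_indices = torch.topk(
    router_logits.softmax(dim=-1), top_k, dim=-1
)
# Renormalize so each token's selected expert weights sum to 1.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)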