add 2 kernels and optimize the calculation of topk_indices (#134)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -153,13 +153,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(
|
||||
M,
|
||||
top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
del y
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
x_shape = out1.shape
|
||||
@@ -168,6 +162,14 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
|
||||
)
|
||||
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
|
||||
del out1, moe_expand
|
||||
out = torch.empty(
|
||||
M,
|
||||
top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=x_q,
|
||||
@@ -182,6 +184,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
||||
# sort_mode=False,
|
||||
act=None,
|
||||
)
|
||||
del x_q, x_scale, sorted_tokens_num_lod,expert_m
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
||||
output = torch.empty(
|
||||
|
||||
Reference in New Issue
Block a user