Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -405,7 +405,7 @@ def add_rmsnorm(
|
||||
residual_output: torch.Tensor = None,
|
||||
output_max: torch.Tensor = None,
|
||||
) -> None:
|
||||
xtorch_ops.add_rmsnorm(
|
||||
kunlun_ops.add_rmsnorm(
|
||||
x,
|
||||
y, # previously written as residual; this argument is actually y
|
||||
residual_output=residual_output,
|
||||
@@ -429,7 +429,7 @@ def add_rmsnorm_cuda(
|
||||
residual_output: torch.Tensor = None,
|
||||
output_max: torch.Tensor = None,
|
||||
) -> None:
|
||||
xtorch_ops.add_rmsnorm(
|
||||
kunlun_ops.add_rmsnorm(
|
||||
x,
|
||||
y,
|
||||
residual_output=residual_output,
|
||||
@@ -451,7 +451,7 @@ def rmsnorm(
|
||||
residual_output: torch.Tensor = None,
|
||||
output_max: torch.Tensor = None,
|
||||
) -> None:
|
||||
xtorch_ops.rmsnorm(
|
||||
kunlun_ops.rmsnorm(
|
||||
x,
|
||||
weight,
|
||||
output,
|
||||
@@ -471,7 +471,7 @@ def rmsnorm_cuda(
|
||||
residual_output: torch.Tensor = None,
|
||||
output_max: torch.Tensor = None,
|
||||
) -> None:
|
||||
xtorch_ops.rmsnorm(
|
||||
kunlun_ops.rmsnorm(
|
||||
x,
|
||||
weight,
|
||||
output,
|
||||
@@ -541,7 +541,7 @@ def split_norm_rope_neox(
|
||||
rotary_dim: int,
|
||||
emb_batch_size: int = 1,
|
||||
) -> None:
|
||||
xtorch_ops.split_norm_rope_neox(
|
||||
kunlun_ops.split_norm_rope_neox(
|
||||
q_emb,
|
||||
k_emb,
|
||||
v_out,
|
||||
@@ -577,7 +577,7 @@ def split_norm_rope_neox_cuda(
|
||||
rotary_dim: int,
|
||||
emb_batch_size: int = 1,
|
||||
) -> None:
|
||||
xtorch_ops.split_norm_rope_neox(
|
||||
kunlun_ops.split_norm_rope_neox(
|
||||
q_emb,
|
||||
k_emb,
|
||||
v_out,
|
||||
@@ -649,7 +649,7 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
||||
def silu_and_mul(
|
||||
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
||||
) -> None:
|
||||
xtorch_ops.swiglu(
|
||||
kunlun_ops.swiglu(
|
||||
x=x,
|
||||
y=out,
|
||||
)
|
||||
@@ -659,7 +659,7 @@ def silu_and_mul(
|
||||
def silu_and_mul_cuda(
|
||||
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
||||
) -> None:
|
||||
xtorch_ops.swiglu(
|
||||
kunlun_ops.swiglu(
|
||||
x=x,
|
||||
y=out,
|
||||
)
|
||||
@@ -736,7 +736,7 @@ def moe_softmax_topk(
|
||||
axis: int = -1,
|
||||
turn: bool = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||
kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||
|
||||
|
||||
@impl("_C::moe_softmax_topk", "CUDA")
|
||||
@@ -748,7 +748,7 @@ def moe_softmax_topk_cuda(
|
||||
axis: int = -1,
|
||||
turn: bool = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||
kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||
|
||||
|
||||
def _fake_moe_softmax_topk(
|
||||
@@ -781,7 +781,7 @@ def moe_ffn_block(
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
xtorch_ops.moe_ffn_block(
|
||||
kunlun_ops.moe_ffn_block(
|
||||
x=x,
|
||||
gate_w=gate_w,
|
||||
inter_w=inter_w,
|
||||
@@ -812,7 +812,7 @@ def moe_ffn_block_cuda(
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
xtorch_ops.moe_ffn_block(
|
||||
kunlun_ops.moe_ffn_block(
|
||||
x=x,
|
||||
gate_w=gate_w,
|
||||
inter_w=inter_w,
|
||||
@@ -863,7 +863,7 @@ def moe_ffn_per_token_block(
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
) -> None:
|
||||
xtorch_ops.moe_ffn_per_token_block(
|
||||
kunlun_ops.moe_ffn_per_token_block(
|
||||
x=x,
|
||||
inter_weight=inter_weight,
|
||||
inter_scale=inter_scale,
|
||||
@@ -897,7 +897,7 @@ def moe_ffn_per_token_block_cuda(
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
) -> None:
|
||||
xtorch_ops.moe_ffn_per_token_block(
|
||||
kunlun_ops.moe_ffn_per_token_block(
|
||||
x=x,
|
||||
inter_weight=inter_weight,
|
||||
inter_scale=inter_scale,
|
||||
@@ -948,7 +948,7 @@ def rotary_embedding(
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
xtorch_ops.rotary_embedding(
|
||||
kunlun_ops.rotary_embedding(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -967,7 +967,7 @@ def rotary_embedding_cuda(
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
xtorch_ops.rotary_embedding(
|
||||
kunlun_ops.rotary_embedding(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -999,7 +999,7 @@ def gemm_I8_I8_bf16_nt(
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
||||
kunlun_ops.gemm_I8_I8_bf16_nt(
|
||||
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
||||
)
|
||||
|
||||
@@ -1012,7 +1012,7 @@ def gemm_I8_I8_bf16_nt_cuda(
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
||||
kunlun_ops.gemm_I8_I8_bf16_nt(
|
||||
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
||||
)
|
||||
|
||||
@@ -1038,7 +1038,7 @@ def moe_softmax_topk_norm(
|
||||
block_statistic: torch.Tensor,
|
||||
stable: bool = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_softmax_topk_norm(
|
||||
kunlun_ops.moe_softmax_topk_norm(
|
||||
x, normed_score, topk_index, block_statistic, stable
|
||||
)
|
||||
|
||||
@@ -1051,7 +1051,7 @@ def moe_softmax_topk_norm_cuda(
|
||||
block_statistic: torch.Tensor,
|
||||
stable: bool = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_softmax_topk_norm(
|
||||
kunlun_ops.moe_softmax_topk_norm(
|
||||
x, normed_score, topk_index, block_statistic, stable
|
||||
)
|
||||
|
||||
@@ -1071,14 +1071,14 @@ moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
|
||||
|
||||
@custom_op("_C::gen_block_statistic", mutates_args=())
|
||||
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
|
||||
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||
kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||
|
||||
|
||||
@impl("_C::gen_block_statistic", "CUDA")
|
||||
def gen_block_statistic_cuda(
|
||||
topk_ids: torch.Tensor, block_statistic: torch.Tensor
|
||||
) -> None:
|
||||
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||
kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||
|
||||
|
||||
def fake_gen_block_statistic(
|
||||
@@ -1101,7 +1101,7 @@ def moe_pre_sorted(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
index_have_neg: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.moe_pre_sorted(
|
||||
kunlun_ops.moe_pre_sorted(
|
||||
x,
|
||||
topk_index,
|
||||
block_statistic,
|
||||
@@ -1123,7 +1123,7 @@ def moe_pre_sorted_cuda(
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
index_have_neg: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.moe_pre_sorted(
|
||||
kunlun_ops.moe_pre_sorted(
|
||||
x,
|
||||
topk_index,
|
||||
block_statistic,
|
||||
@@ -1171,7 +1171,7 @@ def moe_fc(
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_fc(
|
||||
kunlun_ops.moe_fc(
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
@@ -1214,7 +1214,7 @@ def moe_fc_cuda(
|
||||
use_pack_int4: Optional[bool] = False,
|
||||
sort_mode: Optional[bool] = True,
|
||||
) -> None:
|
||||
xtorch_ops.moe_fc(
|
||||
kunlun_ops.moe_fc(
|
||||
x=x,
|
||||
weight=weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
@@ -1270,7 +1270,7 @@ def moe_post(
|
||||
dequant_scale: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||
kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||
|
||||
|
||||
@impl("_C::moe_post", "CUDA")
|
||||
@@ -1281,7 +1281,7 @@ def moe_post_cuda(
|
||||
dequant_scale: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
) -> None:
|
||||
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||
kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||
|
||||
|
||||
def fake_moe_post(
|
||||
@@ -1308,7 +1308,7 @@ def moe_sigmoid_group_topk_norm(
|
||||
n_group: int,
|
||||
topk_group: int,
|
||||
) -> None:
|
||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
||||
kunlun_ops.moe_sigmoid_group_topk_norm(
|
||||
x=x,
|
||||
norm_score=norm_score,
|
||||
topk_index=topk_index,
|
||||
@@ -1331,7 +1331,7 @@ def moe_sigmoid_group_topk_norm_cuda(
|
||||
n_group: int,
|
||||
topk_group: int,
|
||||
) -> None:
|
||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
||||
kunlun_ops.moe_sigmoid_group_topk_norm(
|
||||
x=x,
|
||||
norm_score=norm_score,
|
||||
topk_index=topk_index,
|
||||
@@ -1376,7 +1376,7 @@ def awq_dequantize(
|
||||
device=qweight.device,
|
||||
)
|
||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||
xtorch_ops.awq_dequantize(
|
||||
kunlun_ops.awq_dequantize(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
zeros=zeros,
|
||||
@@ -1402,7 +1402,7 @@ def awq_dequantize_cuda(
|
||||
device=qweight.device,
|
||||
)
|
||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||
xtorch_ops.awq_dequantize(
|
||||
out = kunlun_ops.awq_dequantize(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
zeros=zeros,
|
||||
@@ -1447,7 +1447,7 @@ def awq_gemm(
|
||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||
)
|
||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||
xtorch_ops.awq_gemm(
|
||||
kunlun_ops.awq_gemm(
|
||||
x=x,
|
||||
w=qweight,
|
||||
scale=scale,
|
||||
@@ -1471,7 +1471,7 @@ def awq_gemm_cuda(
|
||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||
)
|
||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||
xtorch_ops.awq_gemm(
|
||||
kunlun_ops.awq_gemm(
|
||||
x=x,
|
||||
w=qweight,
|
||||
scale=scale,
|
||||
@@ -1508,7 +1508,7 @@ def gptq_shuffle(
|
||||
q_perm: torch.Tensor,
|
||||
bit: int,
|
||||
) -> None:
|
||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
|
||||
|
||||
@impl("_C::gptq_shuffle", "CUDA")
|
||||
@@ -1517,7 +1517,7 @@ def gptq_shuffle_cuda(
|
||||
q_perm: torch.Tensor,
|
||||
bit: int,
|
||||
) -> None:
|
||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||
|
||||
|
||||
def _fake_gptq_shuffle(
|
||||
@@ -1541,7 +1541,7 @@ def concat_and_cache_mla(
|
||||
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kunlun_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
@@ -1556,7 +1556,7 @@ def concat_and_cache_mla_cuda(
|
||||
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kunlun_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
@@ -1598,7 +1598,7 @@ def scaled_int8_quant(
|
||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||
if symmetric:
|
||||
# NOTE: For quant2d ops, scale represents max.
|
||||
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||
else:
|
||||
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
||||
x_q, x.contiguous(), scale, azp
|
||||
@@ -1625,7 +1625,7 @@ def scaled_int8_quant_cuda(
|
||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||
if symmetric:
|
||||
# NOTE: For quant2d ops, scale represents max.
|
||||
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||
else:
|
||||
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
||||
x_q, x.contiguous(), scale, azp
|
||||
@@ -1777,7 +1777,7 @@ def matmul(
|
||||
dtype=out_dtype,
|
||||
device=x.device,
|
||||
)
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=out,
|
||||
@@ -1814,7 +1814,7 @@ def matmul_cuda(
|
||||
dtype=out_dtype,
|
||||
device=x.device,
|
||||
)
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=out,
|
||||
@@ -1865,7 +1865,7 @@ def quant2d(
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.quant2d(
|
||||
kunlun_ops.quant2d(
|
||||
x=x,
|
||||
y=x_q,
|
||||
max=max,
|
||||
@@ -1880,7 +1880,7 @@ def quant2d_cuda(
|
||||
max: torch.Tensor,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.quant2d(
|
||||
kunlun_ops.quant2d(
|
||||
x=x,
|
||||
y=x_q,
|
||||
max=max,
|
||||
@@ -1954,7 +1954,7 @@ def I8_mqa_logits(
|
||||
is_causal: Optional[bool] = False,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.I8_mqa_logits(
|
||||
kunlun_ops.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=fused_kv_cache,
|
||||
weights=weights,
|
||||
@@ -1984,7 +1984,7 @@ def I8_mqa_logits_cuda(
|
||||
is_causal: Optional[bool] = False,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.I8_mqa_logits(
|
||||
kunlun_ops.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=fused_kv_cache,
|
||||
weights=weights,
|
||||
@@ -2034,7 +2034,8 @@ def I8_paged_mqa_logits(
|
||||
out: torch.Tensor,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.I8_paged_mqa_logits(
|
||||
kunlun_ops.I8_paged_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=fused_kv_cache,
|
||||
weights=weights,
|
||||
@@ -2060,7 +2061,7 @@ def I8_paged_mqa_logits_cuda(
|
||||
out: torch.Tensor,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.I8_paged_mqa_logits(
|
||||
kunlun_ops.I8_paged_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=fused_kv_cache,
|
||||
weights=weights,
|
||||
@@ -2111,7 +2112,7 @@ def sparse_prefill_fwd_opt(
|
||||
is_causal: Optional[bool] = True,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.sparse_prefill_fwd_opt(
|
||||
kunlun_ops.sparse_prefill_fwd_opt(
|
||||
q=q,
|
||||
kv=kv,
|
||||
indices=indices,
|
||||
@@ -2147,7 +2148,7 @@ def sparse_prefill_fwd_opt_cuda(
|
||||
is_causal: Optional[bool] = True,
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
) -> None:
|
||||
xtorch_ops.sparse_prefill_fwd_opt(
|
||||
kunlun_ops.sparse_prefill_fwd_opt(
|
||||
q=q,
|
||||
kv=kv,
|
||||
indices=indices,
|
||||
@@ -2207,7 +2208,7 @@ def fwd_kvcache_mla(
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
kv_lod_xpu: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
xtorch_ops.fwd_kvcache_mla(
|
||||
kunlun_ops.fwd_kvcache_mla(
|
||||
q_c=q_c,
|
||||
kv_cache=kv_cache,
|
||||
indices=indices,
|
||||
@@ -2241,7 +2242,7 @@ def fwd_kvcache_mla_cuda(
|
||||
use_xfa_boost: Optional[bool] = False,
|
||||
kv_lod_xpu: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
xtorch_ops.fwd_kvcache_mla(
|
||||
kunlun_ops.fwd_kvcache_mla(
|
||||
q_c=q_c,
|
||||
kv_cache=kv_cache,
|
||||
indices=indices,
|
||||
@@ -2293,7 +2294,7 @@ def dequant_int4(
|
||||
int4_signed: bool = True,
|
||||
use_mode_fast: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.dequant_int4(
|
||||
kunlun_ops.dequant_int4(
|
||||
x=x,
|
||||
scale=scale,
|
||||
zero=zero,
|
||||
@@ -2315,7 +2316,7 @@ def dequant_int4_cuda(
|
||||
int4_signed: bool = True,
|
||||
use_mode_fast: bool = False,
|
||||
) -> None:
|
||||
xtorch_ops.dequant_int4(
|
||||
kunlun_ops.dequant_int4(
|
||||
x=x,
|
||||
scale=scale,
|
||||
zero=zero,
|
||||
@@ -2350,7 +2351,10 @@ def fast_topkv2(
|
||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||
) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
@@ -2359,7 +2363,10 @@ def fast_topkv2_cuda(
|
||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||
) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user