Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
Xinyu Dong
2026-02-12 18:13:00 +08:00
committed by GitHub
parent 744719587e
commit bf9369f733
15 changed files with 125 additions and 119 deletions

View File

@@ -28,7 +28,7 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import xtorch_ops
import kunlun_ops
logger.info(f"Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
@@ -71,7 +71,7 @@ class KunlunOps:
):
""" PagedAttentionV1 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -114,7 +114,7 @@ class KunlunOps:
):
""" PagedAttentionV2 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -133,7 +133,7 @@ class KunlunOps:
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
xtorch_ops.silu_and_mul(
kunlun_ops.silu_and_mul(
x,
axis=-1,
turn=True,
@@ -145,7 +145,7 @@ class KunlunOps:
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
xtorch_ops.quick_gelu(
kunlun_ops.quick_gelu(
x,
out=out,
)
@@ -159,7 +159,7 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(
kunlun_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
@@ -172,7 +172,7 @@ class KunlunOps:
):
"""fused_add_rms_norm"""
output = torch.empty_like(x)
xtorch_ops.add_rmsnorm(
kunlun_ops.add_rmsnorm(
x, residual, weight.to(torch.float32), epsilon, out=output
)
fused_input = x + residual
@@ -222,7 +222,7 @@ class KunlunOps:
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
kunlun_ops.mrotary_embedding_neox(
positions,
query_x,
key_x,
@@ -240,7 +240,7 @@ class KunlunOps:
dst,
block_mapping):
""" swap_blocks """
xtorch_ops.swap_blocks(
kunlun_ops.swap_blocks(
src,
dst,
block_mapping
@@ -255,7 +255,7 @@ class KunlunOps:
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
xtorch_ops.copy_blocks(
kunlun_ops.copy_blocks(
key_caches,
value_caches,
block_mapping,
@@ -272,7 +272,7 @@ class KunlunOps:
):
""" reshape_and_cache """
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(
kunlun_ops.reshape_and_cache(
key,
value,
key_cache,
@@ -308,7 +308,7 @@ class KunlunOps:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
kunlun_ops.attention(
q=query,
k_cache=key,
v_cache=value,
@@ -337,7 +337,7 @@ class KunlunOps:
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
if residual is None:
@@ -360,7 +360,7 @@ class KunlunOps:
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
return out, out_scale
@@ -388,7 +388,7 @@ class KunlunOps:
dtype=torch.float16,
device=weight.device)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
@@ -642,7 +642,7 @@ class KunlunOps:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
kv_lora_rank,
@@ -688,7 +688,7 @@ class KunlunOps:
threshold: float = 20.0,
) -> torch.Tensor:
"""fused_gdn_gating"""
output = xtorch_ops.fused_gdn_gating(
output = kunlun_ops.fused_gdn_gating(
A_log,
a,
dt_bias,
@@ -713,7 +713,7 @@ class KunlunOps:
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
'''
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)