Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
Xinyu Dong
2026-02-12 18:13:00 +08:00
committed by GitHub
parent 744719587e
commit bf9369f733
15 changed files with 125 additions and 119 deletions

View File

@@ -28,7 +28,7 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import xtorch_ops
import kunlun_ops
logger.info(f"Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
@@ -71,7 +71,7 @@ class KunlunOps:
):
""" PagedAttentionV1 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -114,7 +114,7 @@ class KunlunOps:
):
""" PagedAttentionV2 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -133,7 +133,7 @@ class KunlunOps:
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
xtorch_ops.silu_and_mul(
kunlun_ops.silu_and_mul(
x,
axis=-1,
turn=True,
@@ -145,7 +145,7 @@ class KunlunOps:
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
xtorch_ops.quick_gelu(
kunlun_ops.quick_gelu(
x,
out=out,
)
@@ -159,7 +159,7 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(
kunlun_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
@@ -172,7 +172,7 @@ class KunlunOps:
):
"""fused_add_rms_norm"""
output = torch.empty_like(x)
xtorch_ops.add_rmsnorm(
kunlun_ops.add_rmsnorm(
x, residual, weight.to(torch.float32), epsilon, out=output
)
fused_input = x + residual
@@ -222,7 +222,7 @@ class KunlunOps:
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
kunlun_ops.mrotary_embedding_neox(
positions,
query_x,
key_x,
@@ -240,7 +240,7 @@ class KunlunOps:
dst,
block_mapping):
""" swap_blocks """
xtorch_ops.swap_blocks(
kunlun_ops.swap_blocks(
src,
dst,
block_mapping
@@ -255,7 +255,7 @@ class KunlunOps:
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
xtorch_ops.copy_blocks(
kunlun_ops.copy_blocks(
key_caches,
value_caches,
block_mapping,
@@ -272,7 +272,7 @@ class KunlunOps:
):
""" reshape_and_cache """
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(
kunlun_ops.reshape_and_cache(
key,
value,
key_cache,
@@ -308,7 +308,7 @@ class KunlunOps:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
kunlun_ops.attention(
q=query,
k_cache=key,
v_cache=value,
@@ -337,7 +337,7 @@ class KunlunOps:
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
if residual is None:
@@ -360,7 +360,7 @@ class KunlunOps:
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
return out, out_scale
@@ -388,7 +388,7 @@ class KunlunOps:
dtype=torch.float16,
device=weight.device)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
@@ -642,7 +642,7 @@ class KunlunOps:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
kv_lora_rank,
@@ -688,7 +688,7 @@ class KunlunOps:
threshold: float = 20.0,
) -> torch.Tensor:
"""fused_gdn_gating"""
output = xtorch_ops.fused_gdn_gating(
output = kunlun_ops.fused_gdn_gating(
A_log,
a,
dt_bias,
@@ -713,7 +713,7 @@ class KunlunOps:
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
'''
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)

View File

@@ -93,7 +93,7 @@ class SiluAndMul(CustomOp):
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""forward_cuda"""
import xtorch_ops
import kunlun_ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
@@ -103,7 +103,7 @@ class SiluAndMul(CustomOp):
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
"""forward_kunlun"""
import xtorch_ops
import kunlun_ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
@@ -251,14 +251,14 @@ class GeluAndMul(CustomOp):
无。
"""
# from vllm import _custom_ops as ops
import xtorch_ops
import kunlun_ops
# d = x.shape[-1] // 2
# output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(x, dtype=x.dtype, device=x.device)
if self.approximate == "none":
# ops.gelu_and_mul(out, x)
print(x,x.shape)
xtorch_ops.gelu(x, out)
kunlun_ops.gelu(x, out)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out

View File

@@ -7,7 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
import kunlun_ops
logger = init_logger(__name__)
@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
is_context = False
vo_head_dim = -1
xtorch_ops.paged_attention(out,
kunlun_ops.paged_attention(out,
q,
k_cache, None,
block_table,
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
"""
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:

View File

@@ -3,7 +3,7 @@
from typing import Optional
import torch
import xtorch_ops
import kunlun_ops
from vllm.platforms import current_platform
@@ -16,7 +16,7 @@ def merge_attn_states(
output_lse: Optional[torch.Tensor] = None,
) -> None:
return xtorch_ops.attention_merge_stage(
return kunlun_ops.attention_merge_stage(
prefix_output,
prefix_lse,
suffix_output,

View File

@@ -11,7 +11,7 @@ from typing import Optional
import torch
import xtorch_ops
import kunlun_ops
class FusedRecurrentFunction(torch.autograd.Function):
@@ -31,7 +31,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
q.contiguous(),
k.contiguous(),
v.contiguous(),

View File

@@ -13,7 +13,7 @@ from typing import Optional
import torch
from vllm.triton_utils import tl, triton
import xtorch_ops
import kunlun_ops
BT_LIST = [8, 16, 32, 64, 128]
@@ -149,5 +149,5 @@ def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
out = torch.empty_like(x)
xtorch_ops.l2norm(x, out, eps)
kunlun_ops.l2norm(x, out, eps)
return out

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import xtorch_ops
import kunlun_ops
def vllm_kunlun_forward_cuda(
self,

View File

@@ -12,7 +12,7 @@ import torch.nn.functional as F
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
import xtorch_ops
import kunlun_ops
@triton.jit()
@@ -1212,7 +1212,7 @@ def torch_causal_conv1d_update(
tmp_hidden_states = hidden_states_new[:, :, -state_len:]
ori_shape = tmp_hidden_states.shape
tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape)
xtorch_ops.reshape_and_cache_flash(
kunlun_ops.reshape_and_cache_flash(
tmp_hidden_states,
tmp_hidden_states,
cast_conv_state,

View File

@@ -113,7 +113,7 @@ class KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase):
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# NOTE: xtorch_ops use max as scale
# NOTE: kunlun_ops use max as scale
with torch.no_grad():
layer.w13_weight_scale.mul_(127.0)
layer.w2_weight_scale.mul_(127.0)