Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
Author: Xinyu Dong
Date: 2026-02-12 18:13:00 +08:00 (committed by GitHub)
parent 744719587e
commit bf9369f733
15 changed files with 125 additions and 119 deletions
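The diff is a mechanical rename of the kernel package at every call site: each xtorch_ops.* call becomes the corresponding kunlun_ops.* call. As a sketch of the same pattern (hypothetical, not part of this commit), a thin compatibility module could prefer the new package and fall back to the old one, assuming kunlun_ops mirrors the xtorch_ops operator names:

# ops_compat.py -- hypothetical shim, not part of this commit.
# Assumption: kunlun_ops exposes the same operator names as xtorch_ops.
try:
    import kunlun_ops as ops   # new Kunlun kernel package
except ImportError:
    import xtorch_ops as ops   # legacy XTorch package as fallback

Call sites would then write ops.attention(...), ops.quant2d(...), etc., so a migration like this one touches a single module instead of every file.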


@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
-import xtorch_ops
+import kunlun_ops
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
) * q_len
sorted_tokens_idx = torch.arange(
self.num_heads * q_len, dtype=torch.int, device="cuda")
-xtorch_ops.mla_bmm_I8(
+kunlun_ops.mla_bmm_I8(
x.contiguous(), # [1, 16, 512] torch.float16
self.W_UV, # [16, 128, 512] torch.int8
self.W_UV_SCALE, # [2048, 1] torch.float32
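For reference, a plain-PyTorch sketch of what an int8 batched matmul with a per-(head, row) scale computes, matching the shape comments above. The name mla_bmm_i8_reference and the exact dequantization layout are assumptions, not the kernel's implementation:

import torch

def mla_bmm_i8_reference(x, w_int8, w_scale):
    # x:       [1, H, K] float16,  e.g. [1, 16, 512]
    # w_int8:  [H, N, K] int8,     e.g. [16, 128, 512]
    # w_scale: [H * N, 1] float32, e.g. [2048, 1] (2048 = 16 * 128)
    h, n, k = w_int8.shape
    w = w_int8.float() * w_scale.view(h, n, 1)    # dequantize row-wise
    out = torch.bmm(x.transpose(0, 1).float(),    # [H, 1, K]
                    w.transpose(1, 2))            # @ [H, K, N] -> [H, 1, N]
    return out.transpose(0, 1).to(x.dtype)        # [1, H, N]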
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
tp_q_head_num=q.size(1)
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
softmax_lse.fill_(float('-inf'))
-xtorch_ops.attention(
+kunlun_ops.attention(
q=q,
k_cache=k,
v_cache=maybe_padded_v,
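A note on the softmax_lse buffer filled with float('-inf') in this hunk: -inf is the identity element when partial attention results are merged via log-sum-exp, since exp(-inf) contributes nothing. A minimal plain-PyTorch demonstration (not the kernel's code):

import torch

lse = torch.full((8, 4), float("-inf"))   # "no contribution yet"
partial = torch.randn(8, 4)               # LSE from one attention chunk
merged = torch.logaddexp(lse, partial)    # log(exp(lse) + exp(partial))
assert torch.allclose(merged, partial)    # -inf leaves the result unchanged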
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
dtype=torch.float, device=kv_b_proj_weight.device)
-xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
+kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
self.W_UV = W_UV.contiguous()
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
else:
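quant2d above fills the preallocated W_UK_T and W_UK_SCALE tensors. A plausible plain-PyTorch reference, inferred from the [rows, 1] float32 scale tensor allocated in this hunk; symmetric per-row int8 quantization is an assumption about the scheme:

import torch

def quant2d_reference(src_fp, dst_int8, dst_scale):
    rows = dst_scale.shape[0]
    flat = src_fp.reshape(rows, -1).float()
    scale = flat.abs().amax(dim=1, keepdim=True) / 127.0    # per-row scale
    dst_scale.copy_(scale)
    q = torch.round(flat / scale.clamp(min=1e-8)).clamp(-127, 127)
    dst_int8.copy_(q.to(torch.int8).reshape(dst_int8.shape))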
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
-xtorch_ops.concat_and_cache_mla(
+kunlun_ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
attn_metadata.slot_mapping.flatten(),
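The semantics of concat_and_cache_mla, sketched in plain PyTorch as an assumption from its arguments: concatenate each token's compressed latent and rotary part, then scatter the result into the flat cache slots named by slot_mapping:

import torch

def concat_and_cache_mla_reference(k_c_normed, k_pe, slot_mapping, kv_cache):
    # k_c_normed: [T, Dc], k_pe: [T, Dr]; kv_cache holds [num_slots, Dc + Dr]
    entry = torch.cat([k_c_normed, k_pe], dim=-1)
    kv_cache.view(-1, entry.shape[-1])[slot_mapping] = entry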
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
sorted_tokens_idx = torch.arange(
self.num_heads * q_len, dtype=torch.int, device="cuda")
extra_params = {"trans": False}
-xtorch_ops.mla_bmm_I8(
+kunlun_ops.mla_bmm_I8(
decode_q_nope.contiguous(),
self.W_UK_T,
self.W_UK_SCALE,