Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
-import xtorch_ops
+import kunlun_ops
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
) * q_len
|
||||
sorted_tokens_idx = torch.arange(
|
||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||
-xtorch_ops.mla_bmm_I8(
+kunlun_ops.mla_bmm_I8(
|
||||
x.contiguous(), # [1, 16, 512] torch.float16
|
||||
self.W_UV, # [16, 128, 512] torch.int8
|
||||
self.W_UV_SCALE, # [2048, 1] torch.float32
|
||||
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
tp_q_head_num=q.size(1)
|
||||
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
|
||||
softmax_lse.fill_(float('-inf'))
|
||||
-xtorch_ops.attention(
+kunlun_ops.attention(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
v_cache=maybe_padded_v,
|
||||
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
|
||||
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
|
||||
dtype=torch.float, device=kv_b_proj_weight.device)
|
||||
-xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
+kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
||||
self.W_UV = W_UV.contiguous()
|
||||
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
|
||||
else:
|
||||
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
-xtorch_ops.concat_and_cache_mla(
+kunlun_ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
sorted_tokens_idx = torch.arange(
|
||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||
extra_params = {"trans": False}
|
||||
-xtorch_ops.mla_bmm_I8(
+kunlun_ops.mla_bmm_I8(
|
||||
decode_q_nope.contiguous(),
|
||||
self.W_UK_T,
|
||||
self.W_UK_SCALE,
|
||||
|
||||
Reference in New Issue
Block a user