Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
Xinyu Dong
2026-02-12 18:13:00 +08:00
committed by GitHub
parent 744719587e
commit bf9369f733
15 changed files with 125 additions and 119 deletions

View File

@@ -28,9 +28,9 @@ from typing import (
TypeVar,
)
import kunlun_ops
import numpy as np
import torch
import xtorch_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
AttentionType,
)
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
def __post_init__(self):
"""__post_init__"""
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
@property
def is_all_encoder_attn_metadata_set(self):
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
"""build"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
common_prefix_len = common_prefix_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
self.device, non_blocking=True
@@ -771,7 +771,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# not cached. This happens during the initial memory
value = value.contiguous()
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
kunlun_ops.reshape_and_cache(
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
@@ -781,7 +781,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
xtorch_ops.reshape_and_cache_flash(
kunlun_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
@@ -791,7 +791,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
@@ -811,7 +810,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Prefix cache
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
xtorch_ops.prefill_attention(
kunlun_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache,
@@ -827,7 +826,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
softmax_lse=None,
)
else:
xtorch_ops.prefill_attention(
kunlun_ops.prefill_attention(
q=prefill_query,
k=prefill_key,
v=prefill_value,
@@ -860,9 +859,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
decode_meta.block_tables * 2
) # only test in Qwen3-Next
sig = inspect.signature(xtorch_ops.speculative_attention)
sig = inspect.signature(kunlun_ops.speculative_attention)
if "max_window_size" in sig.parameters:
xtorch_ops.speculative_attention(
kunlun_ops.speculative_attention(
out=output[:num_decode_tokens],
# Only MLA support q len > 1 right now
q=decode_query.unsqueeze(0),
@@ -890,7 +889,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
),
)
elif not attn_metadata.is_speculative:
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
@@ -910,7 +909,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
out = output[:num_decode_tokens]
assert out.is_contiguous()
xtorch_ops.speculative_attention(
kunlun_ops.speculative_attention(
out=out.view(batch_size, qlen, head_num, self.head_size),
q=decode_query.view(batch_size, qlen, head_num, head_dim),
k_cache=key_cache,