Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -28,9 +28,9 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import kunlun_ops
|
||||
import numpy as np
|
||||
import torch
|
||||
import xtorch_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
CommonAttentionMetadata,
|
||||
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def __post_init__(self):
|
||||
"""__post_init__"""
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
|
||||
"""build"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
common_prefix_len = common_prefix_len
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
||||
self.device, non_blocking=True
|
||||
@@ -771,7 +771,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# not cached. This happens during the initial memory
|
||||
value = value.contiguous()
|
||||
if key_cache.is_contiguous():
|
||||
xtorch_ops.reshape_and_cache(
|
||||
kunlun_ops.reshape_and_cache(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
@@ -781,7 +781,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
cast_key_cache,
|
||||
@@ -791,7 +791,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
assert attn_type == AttentionType.DECODER
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
@@ -811,7 +810,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Prefix cache
|
||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||
v=value_cache,
|
||||
@@ -827,7 +826,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
softmax_lse=None,
|
||||
)
|
||||
else:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
@@ -860,9 +859,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
decode_meta.block_tables * 2
|
||||
) # only test in Qwen3-Next
|
||||
|
||||
sig = inspect.signature(xtorch_ops.speculative_attention)
|
||||
sig = inspect.signature(kunlun_ops.speculative_attention)
|
||||
if "max_window_size" in sig.parameters:
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=output[:num_decode_tokens],
|
||||
# Only MLA supports q len > 1 right now
|
||||
q=decode_query.unsqueeze(0),
|
||||
@@ -890,7 +889,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
),
|
||||
)
|
||||
elif not attn_metadata.is_speculative:
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -910,7 +909,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
out = output[:num_decode_tokens]
|
||||
assert out.is_contiguous()
|
||||
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=out.view(batch_size, qlen, head_num, self.head_size),
|
||||
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
||||
k_cache=key_cache,
|
||||
|
||||
Reference in New Issue
Block a user