Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -28,9 +28,9 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import kunlun_ops
|
||||
import numpy as np
|
||||
import torch
|
||||
import xtorch_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
CommonAttentionMetadata,
|
||||
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def __post_init__(self):
|
||||
"""__post_init__"""
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
|
||||
"""build"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
common_prefix_len = common_prefix_len
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
||||
self.device, non_blocking=True
|
||||
@@ -771,7 +771,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# not cached. This happens during the initial memory
|
||||
value = value.contiguous()
|
||||
if key_cache.is_contiguous():
|
||||
xtorch_ops.reshape_and_cache(
|
||||
kunlun_ops.reshape_and_cache(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
@@ -781,7 +781,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
cast_key_cache,
|
||||
@@ -791,7 +791,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
assert attn_type == AttentionType.DECODER
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
@@ -811,7 +810,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Prefix cache
|
||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||
v=value_cache,
|
||||
@@ -827,7 +826,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
softmax_lse=None,
|
||||
)
|
||||
else:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
@@ -860,9 +859,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
decode_meta.block_tables * 2
|
||||
) # only test in Qwen3-Next
|
||||
|
||||
sig = inspect.signature(xtorch_ops.speculative_attention)
|
||||
sig = inspect.signature(kunlun_ops.speculative_attention)
|
||||
if "max_window_size" in sig.parameters:
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=output[:num_decode_tokens],
|
||||
# Only MLA support q len > 1 right now
|
||||
q=decode_query.unsqueeze(0),
|
||||
@@ -890,7 +889,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
),
|
||||
)
|
||||
elif not attn_metadata.is_speculative:
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -910,7 +909,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
out = output[:num_decode_tokens]
|
||||
assert out.is_contiguous()
|
||||
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=out.view(batch_size, qlen, head_num, self.head_size),
|
||||
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
||||
k_cache=key_cache,
|
||||
|
||||
@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
) * q_len
|
||||
sorted_tokens_idx = torch.arange(
|
||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||
xtorch_ops.mla_bmm_I8(
|
||||
kunlun_ops.mla_bmm_I8(
|
||||
x.contiguous(), # [1, 16, 512] torch.float16
|
||||
self.W_UV, # [16, 128, 512] torch.int8
|
||||
self.W_UV_SCALE, # [2048, 1] torch.float32
|
||||
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
tp_q_head_num=q.size(1)
|
||||
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
|
||||
softmax_lse.fill_(float('-inf'))
|
||||
xtorch_ops.attention(
|
||||
kunlun_ops.attention(
|
||||
q=q,
|
||||
k_cache=k,
|
||||
v_cache=maybe_padded_v,
|
||||
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
|
||||
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
|
||||
dtype=torch.float, device=kv_b_proj_weight.device)
|
||||
xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
||||
kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
||||
self.W_UV = W_UV.contiguous()
|
||||
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
|
||||
else:
|
||||
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kunlun_ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
sorted_tokens_idx = torch.arange(
|
||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||
extra_params = {"trans": False}
|
||||
xtorch_ops.mla_bmm_I8(
|
||||
kunlun_ops.mla_bmm_I8(
|
||||
decode_q_nope.contiguous(),
|
||||
self.W_UK_T,
|
||||
self.W_UK_SCALE,
|
||||
|
||||
Reference in New Issue
Block a user