[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# Prefix cache loc
|
||||
kv_lod_cpu: Optional[torch.Tensor] = None
|
||||
kv_lod_xpu: Optional[torch.Tensor] = None
|
||||
@@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
|
||||
# For hybrid Attention (Qwen3-Next.)
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = prefill_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache (block_num, head, block_size, dim)
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=tmp_block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None,
|
||||
sink=self.sinks
|
||||
)
|
||||
# For hybrid Attention (Qwen3-Next)
|
||||
tmp_block_tables = prefill_meta.block_tables * 2
|
||||
|
||||
# Prefix cache
|
||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=tmp_block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None
|
||||
)
|
||||
else:
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None,
|
||||
swa_left = self.sliding_window if self.sliding_window is not None else -1,
|
||||
swa_right = 0 if self.sliding_window is not None else -1,
|
||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
decode_query = query[:num_decode_tokens]
|
||||
|
||||
# For hybrid Attention (Qwen3-Next
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = decode_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_tables=tmp_block_tables,
|
||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||
is_context=False,
|
||||
is_causal=True,
|
||||
|
||||
xtorch_ops.speculative_attention(
|
||||
out=output[:num_decode_tokens],
|
||||
vo_head_dim=self.head_size
|
||||
)
|
||||
# Only MLA support q len > 1 right now
|
||||
q=decode_query.unsqueeze(0),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||
batch_num=decode_meta.block_tables.shape[0],
|
||||
# TODO (@xyDong23): Support MTP(q lens >1)
|
||||
qlen=1,
|
||||
# TODO (@xyDong23): Support max_context_len to (262144)
|
||||
max_context_len=131072,
|
||||
head_num=self.num_heads,
|
||||
head_dim=self.head_size,
|
||||
scale=0.0,
|
||||
kv_head_num=self.num_kv_heads,
|
||||
block_size=key_cache.shape[2],
|
||||
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
||||
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
|
||||
block_tables=tmp_block_tables,
|
||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
def use_cascade_attention(
|
||||
@@ -788,4 +817,4 @@ def use_cascade_attention(
|
||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||
|
||||
# Use cascade attention if it is faster than FlashDecoding.
|
||||
return cascade_time < flash_decoding_time
|
||||
return cascade_time < flash_decoding_time
|
||||
Reference in New Issue
Block a user