[Feature] Support XiaoMi MIMO Flash V2 (#62)

* [Feature] Support MIMO Flash V2
Xinyu Dong
2025-12-31 10:16:33 +08:00
committed by GitHub
parent 341dc7f296
commit b3c30a3cb9
12 changed files with 1530 additions and 690 deletions


@@ -148,7 +148,6 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None
    # Prefix cache loc
    kv_lod_cpu: Optional[torch.Tensor] = None
    kv_lod_xpu: Optional[torch.Tensor] = None
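For reference, a minimal sketch of the cumulative-offset ("lod") convention these fields appear to follow, assuming kv_lod_cpu/kv_lod_xpu use the same layout as seq_start_loc described in the comment above; the lengths and the torch.cumsum construction below are illustrative, not taken from this diff.

import torch

# Hypothetical per-sequence lengths (illustrative only).
query_lens = torch.tensor([4, 6])   # new query tokens per sequence
kv_lens = torch.tensor([9, 6])      # prefix-cached + new tokens per sequence

# Cumulative start offsets with a leading zero, e.g. [4, 6] -> [0, 4, 10].
query_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.cumsum(query_lens, dim=0)])
kv_lod = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.cumsum(kv_lens, dim=0)])

print(query_start_loc.tolist())  # [0, 4, 10]
print(kv_lod.tolist())           # [0, 9, 15]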
@@ -563,9 +562,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
        if blocksparse_params is not None:
            raise ValueError(
                "kunlunAttention does not support block-sparse attention.")
        # if logits_soft_cap is not None:
        #     raise ValueError(
        #         "kunlunAttention does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
@@ -673,51 +669,84 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
            # For hybrid Attention (Qwen3-Next)
            if key_cache.is_contiguous():
                tmp_block_tables = prefill_meta.block_tables
            else:
                tmp_block_tables = prefill_meta.block_tables * 2  # only tested on Qwen3-Next
            xtorch_ops.prefill_attention(
                q=prefill_query,
                k=key_cache,  # Key Cache (block_num, head, block_size, dim)
                v=value_cache,
                out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                is_causal=True,
                is_prefix_cache=True,
                block_table=tmp_block_tables,
                context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                context_qlen_lod_xpu=prefill_meta.query_start_loc,
                context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
                context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
                alibi_slopes=self.alibi_slopes,
                softmax_lse=None,
                sink=self.sinks
            )
            # For hybrid Attention (Qwen3-Next)
            tmp_block_tables = prefill_meta.block_tables * 2
            # Prefix cache
            if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
                xtorch_ops.prefill_attention(
                    q=prefill_query,
                    k=key_cache,  # Key Cache [block_num, head, block_size, dim]
                    v=value_cache,
                    out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                    is_causal=True,
                    is_prefix_cache=True,
                    block_table=tmp_block_tables,
                    context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                    context_qlen_lod_xpu=prefill_meta.query_start_loc,
                    context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
                    context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
                    alibi_slopes=self.alibi_slopes,
                    softmax_lse=None
                )
            else:
                xtorch_ops.prefill_attention(
                    q=prefill_query,
                    k=prefill_key,
                    v=prefill_value,
                    out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                    is_causal=True,
                    context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                    context_qlen_lod_xpu=prefill_meta.query_start_loc,
                    alibi_slopes=self.alibi_slopes,
                    softmax_lse=None,
                    swa_left=self.sliding_window if self.sliding_window is not None else -1,
                    swa_right=0 if self.sliding_window is not None else -1,
                    sink=self.sinks.to(torch.float32) if self.sinks is not None else None
                )
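The new prefill path above dispatches on whether the batch carries prefix-cached KV: if the total KV length recorded in kv_lod_cpu exceeds the total query length in query_start_loc_host, keys and values are read back from the paged cache (is_prefix_cache=True); otherwise the freshly computed prefill_key/prefill_value are used directly. A toy illustration of that check, with invented numbers:

# Two sequences with 4 and 6 new query tokens; the first also has 5 cached tokens.
query_start_loc_host = [0, 4, 10]   # cumulative query lengths
kv_lod_cpu = [0, 9, 15]             # cumulative KV lengths incl. cached prefix

if query_start_loc_host[-1] != kv_lod_cpu[-1]:
    print("prefix cache present: attend against the paged KV cache")   # 10 != 15
else:
    print("no prefix cache: attend against this step's K/V directly")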
        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            decode_query = query[:num_decode_tokens]
            # For hybrid Attention (Qwen3-Next)
            if key_cache.is_contiguous():
                tmp_block_tables = decode_meta.block_tables
            else:
                tmp_block_tables = decode_meta.block_tables * 2  # only tested on Qwen3-Next
            xtorch_ops.paged_attention(
                x=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                block_tables=tmp_block_tables,
                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
                context_lens_xpu=decode_meta.seq_lens_tensor,
                is_context=False,
                is_causal=True,
                out=output[:num_decode_tokens],
                vo_head_dim=self.head_size
            )
            xtorch_ops.speculative_attention(
                out=output[:num_decode_tokens],
                # Only MLA supports q len > 1 right now
                q=decode_query.unsqueeze(0),
                k_cache=key_cache,
                v_cache=value_cache,
                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
                context_lens_xpu=decode_meta.seq_lens_tensor,
                batch_num=decode_meta.block_tables.shape[0],
                # TODO (@xyDong23): Support MTP (q len > 1)
                qlen=1,
                # TODO (@xyDong23): Support max_context_len up to 262144
                max_context_len=131072,
                head_num=self.num_heads,
                head_dim=self.head_size,
                scale=0.0,
                kv_head_num=self.num_kv_heads,
                block_size=key_cache.shape[2],
                max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
                max_window_size=self.sliding_window if self.sliding_window is not None else -1,
                block_tables=tmp_block_tables,
                sink=self.sinks.to(torch.float32) if self.sinks is not None else None
            )
        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
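For orientation, a minimal sketch of how a token's slot is located in a paged key cache with the (block_num, head, block_size, dim) layout noted in the comments above; the block-table row, sizes, and the lookup_key helper are invented for illustration and are not part of this diff.

import torch

block_size, num_kv_heads, head_dim = 16, 2, 8
num_blocks = 4
# Key cache laid out as (block_num, head, block_size, dim), as in the comment above.
key_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)
# One sequence mapped to physical blocks 3 and 1 (toy block-table row).
block_table_row = torch.tensor([3, 1])

def lookup_key(pos: int) -> torch.Tensor:
    # Logical token position -> physical block and offset within that block.
    block_idx = block_table_row[pos // block_size]
    offset = pos % block_size
    return key_cache[block_idx, :, offset, :]   # (head, dim) slice for this token

print(lookup_key(20).shape)  # token 20 -> block_table_row[1] == 1, offset 4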
@@ -788,4 +817,4 @@ def use_cascade_attention(
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time
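For context on the comparison above: cdiv is ceiling division, so cdiv(ctas, num_sms) counts the scheduling waves needed to run that many CTAs on the GPU's SMs, and the cheaper option wins. A toy sketch with invented CTA counts and SM count (the actual cascade_time computation is outside this hunk):

def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of full waves needed to run `a` CTAs on `b` SMs.
    return -(a // -b)

num_sms = 64                                   # toy GPU size, not from this diff
cascade_ctas, flash_decoding_ctas = 96, 200    # invented CTA counts

cascade_time = cdiv(cascade_ctas, num_sms)                 # 2 waves
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)   # 4 waves
print(cascade_time < flash_decoding_time)                  # True -> prefer cascade attention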