[Bugfix] fix the bug of the flash_attention in Qwen3-Next
This commit is contained in:
@@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = prefill_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache (block_num, head, block_size, dim)
|
||||
@@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=prefill_meta.block_tables,
|
||||
block_table=tmp_block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
@@ -782,4 +788,4 @@ def use_cascade_attention(
|
||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||
|
||||
# Use cascade attention if it is faster than FlashDecoding.
|
||||
return cascade_time < flash_decoding_time
|
||||
return cascade_time < flash_decoding_time
|
||||
|
||||
Reference in New Issue
Block a user