[Bugfix] Fix the flash_attention bug in Qwen3-Next
@@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
+
+            if key_cache.is_contiguous():
+                tmp_block_tables = prefill_meta.block_tables
+            else:
+                tmp_block_tables = prefill_meta.block_tables * 2  # only tested on Qwen3-Next
+
             xtorch_ops.prefill_attention(
                 q=prefill_query,
                 k=key_cache,  # Key Cache (block_num, head, block_size, dim)
@@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
                 out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                 is_causal=True,
                 is_prefix_cache=True,
-                block_table=prefill_meta.block_tables,
+                block_table=tmp_block_tables,
                 context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                 context_qlen_lod_xpu=prefill_meta.query_start_loc,
                 context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
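
Why scale the block table by 2? The commit itself does not say, but a plausible reading (an assumption, not confirmed here) is that when the key cache tensor handed to the kernel is a non-contiguous view, each logical cache block sits two physical blocks apart in the underlying storage, so the logical block IDs in block_tables must be doubled before xtorch_ops.prefill_attention dereferences them. A minimal standalone PyTorch sketch of that indexing idea, with made-up shapes:

import torch

# Toy paged key cache with the layout noted in the diff:
# (block_num, head, block_size, dim). All shapes here are invented.
key_cache = torch.zeros(8, 2, 4, 16)

# A strided view keeping every other block is non-contiguous: block i of
# the view starts at physical block 2*i of the underlying storage.
kv_view = key_cache[::2]
assert not kv_view.is_contiguous()

block_tables = torch.tensor([[0, 1, 2]])  # logical block IDs for one sequence

# Mirror of the committed workaround: contiguous cache -> use the table
# as-is; non-contiguous cache -> logical block ID maps to physical block
# 2 * ID, so scale the table before the kernel dereferences it.
tmp_block_tables = block_tables if kv_view.is_contiguous() else block_tables * 2

# Check the mapping: view block 1 aliases physical block 2.
key_cache[2].fill_(7.0)
assert torch.equal(kv_view[1], key_cache[2])

If this reading is right, the is_contiguous() branch keeps other models, whose caches arrive contiguous, on the unscaled table, which matches the "only tested on Qwen3-Next" caveat in the comment.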
@@ -782,4 +788,4 @@ def use_cascade_attention(
     flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
 
     # Use cascade attention if it is faster than FlashDecoding.
-    return cascade_time < flash_decoding_time
\ No newline at end of file
+    return cascade_time < flash_decoding_time
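
For the use_cascade_attention hunk above (which only touches the file's trailing newline), here is a quick worked example of the heuristic those context lines implement. The concrete figures are invented for illustration, and cdiv is assumed to be the usual ceiling-division helper:

def cdiv(a: int, b: int) -> int:
    # Ceiling division: how many full waves `a` CTAs need on `b` SMs.
    return -(a // -b)

# Hypothetical figures: 384 flash-decoding CTAs on 108 SMs take
# ceil(384 / 108) = 4 waves; if cascade attention needs only 3,
# use_cascade_attention returns True.
flash_decoding_ctas, num_sms = 384, 108
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)  # 4
cascade_time = 3
assert cascade_time < flash_decoding_time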