diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 612ca71..4f2555e 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] + + if key_cache.is_contiguous(): + tmp_block_tables = prefill_meta.block_tables + else: + tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next + xtorch_ops.prefill_attention( q=prefill_query, k=key_cache, # Key Cache (block_num, head, block_size, dim) @@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): out=output[num_decode_tokens:attn_metadata.num_actual_tokens], is_causal=True, is_prefix_cache=True, - block_table=prefill_meta.block_tables, + block_table=tmp_block_tables, context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_xpu=prefill_meta.query_start_loc, context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, @@ -782,4 +788,4 @@ def use_cascade_attention( flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) # Use cascade attention if it is faster than FlashDecoding. - return cascade_time < flash_decoding_time \ No newline at end of file + return cascade_time < flash_decoding_time