From 98003e6f8b8bf8d1d71b9b7a0812435216232641 Mon Sep 17 00:00:00 2001
From: Chranos <826995883@qq.com>
Date: Wed, 11 Feb 2026 13:12:46 +0800
Subject: [PATCH] add deepseekv3

---
 .../vllm_mlu/attention/backends/mlu_attn.py   | 69 +++++++++++--------
 1 file changed, 42 insertions(+), 27 deletions(-)

diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
index 5aba1fa..8edb1fd 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
@@ -580,34 +580,49 @@ def unified_flash_attention_v2(
                                         value_cache,
                                         updated_slot_mapping.flatten())
     else:
-        # FIXME: After TMO-1496 is completed, remove this code.
-        if key.stride() != value.stride():
-            key = key.contiguous()
-            value = value.contiguous()
-        if kv_cache_dtype == 'int8':
-            mlu_ops.quant_to_linear_cache(key,
-                                          value,
-                                          key_cache,
-                                          value_cache,
-                                          key_cache_scale,
-                                          value_cache_scale,
-                                          attn_metadata.cu_seq_lens,
-                                          attn_metadata.max_seq_len,
-                                          True,  # packed
-                                          None,  # context_seq_offset
-                                          attn_metadata.batch_ids,
-                                          attn_metadata.slot_mapping_unpaged)
+        # unpaged (linear cache) path
+        if use_mla:
+            # MLA: mirror the handling of the paged path.
+            value_to_cache = None
+            if attn_metadata.prefill_metadata:
+                # MLA prefill cache was already written in forward_prefill; skip.
+                pass
+            else:
+                if kv_cache_dtype == 'int8':
+                    mlu_ops.quant_to_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        key_cache_scale, value_cache_scale,
+                        updated_slot_mapping.flatten())
+                else:
+                    mlu_ops.reshape_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        updated_slot_mapping.flatten())
         else:
-            mlu_ops.reshape_linear_cache(key,
-                                         value,
-                                         key_cache,
-                                         value_cache,
-                                         attn_metadata.cu_seq_lens,
-                                         attn_metadata.max_seq_len,
-                                         True,  # packed
-                                         None,  # context_seq_offset
-                                         attn_metadata.batch_ids,
-                                         attn_metadata.slot_mapping_unpaged)
+            # FIXME: After TMO-1496 is completed, remove this code.
+            if key.stride() != value.stride():
+                key = key.contiguous()
+                value = value.contiguous()
+            if kv_cache_dtype == 'int8':
+                mlu_ops.quant_to_linear_cache(
+                    key, value,
+                    key_cache, value_cache,
+                    key_cache_scale, value_cache_scale,
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,  # packed, context_seq_offset
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
+            else:
+                mlu_ops.reshape_linear_cache(
+                    key, value,
+                    key_cache, value_cache,
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,  # packed, context_seq_offset
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
     if use_mla and attn_metadata.prefill_metadata:
         output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
     else:
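
Reviewer note (not part of the patch): the substantive change is that the
unpaged (linear cache) branch now special-cases use_mla, the DeepSeek-V3
Multi-head Latent Attention path. During MLA prefill the cache write is
skipped, since forward_prefill has already populated the cache; during decode
the paged-cache kernels are reused with value_to_cache=None, which suggests
MLA caches a single compressed latent per token rather than separate K and V
tensors. Below is a minimal sketch of the resulting dispatch under that
reading; the write_unpaged_mla_cache wrapper and the mlu_ops stub are
illustrative only, while the kernel names and metadata fields are taken from
the diff above.

    from types import SimpleNamespace

    # Stub standing in for the real Cambricon MLU kernel bindings so the
    # sketch is self-contained; the real module is vllm_mlu's mlu_ops.
    mlu_ops = SimpleNamespace(
        quant_to_paged_cache=lambda *args: None,
        reshape_paged_cache=lambda *args: None,
    )

    def write_unpaged_mla_cache(attn_metadata, kv_cache_dtype, key,
                                key_cache, value_cache,
                                key_cache_scale, value_cache_scale,
                                updated_slot_mapping):
        """Illustrative mirror of the use_mla branch added by this patch."""
        if attn_metadata.prefill_metadata:
            # Prefill: forward_prefill has already written the MLA cache.
            return
        value_to_cache = None  # MLA has no separate value tensor to cache.
        if kv_cache_dtype == 'int8':
            # Quantized decode write: scales accompany the int8 cache.
            mlu_ops.quant_to_paged_cache(key, value_to_cache,
                                         key_cache, value_cache,
                                         key_cache_scale, value_cache_scale,
                                         updated_slot_mapping.flatten())
        else:
            # Full-precision decode write into the paged cache layout.
            mlu_ops.reshape_paged_cache(key, value_to_cache,
                                        key_cache, value_cache,
                                        updated_slot_mapping.flatten())

One design observation, also hedged: routing MLA through the paged kernels
even on the "unpaged" branch avoids teaching quant_to_linear_cache and
reshape_linear_cache about a missing value tensor, at the cost of requiring
updated_slot_mapping to be valid on this branch as well.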