diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py index 8edb1fd..a75fb3d 100644 --- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py +++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py @@ -582,23 +582,15 @@ def unified_flash_attention_v2( else: # unpaged (linear cache) path if use_mla: - # MLA: 镜像 paged 路径的处理方式 - value_to_cache = None + # MLA cache 是 2D (total_slots, head_dim), + # 不能用 reshape_paged_cache(期望 4D),直接索引写入 if attn_metadata.prefill_metadata: # MLA prefill cache 已在 forward_prefill 中写入,跳过 pass else: - if kv_cache_dtype == 'int8': - mlu_ops.quant_to_paged_cache( - key, value_to_cache, - key_cache, value_cache, - key_cache_scale, value_cache_scale, - updated_slot_mapping.flatten()) - else: - mlu_ops.reshape_paged_cache( - key, value_to_cache, - key_cache, value_cache, - updated_slot_mapping.flatten()) + # key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim) + # key_cache: (total_slots, head_dim) + key_cache[updated_slot_mapping.flatten()] = key.squeeze(1) else: # FIXME: After TMO-1496 is completed, remove this code. if key.stride() != value.stride():