add deepseekv3

This commit is contained in:
Chranos
2026-02-11 13:18:03 +08:00
parent 98003e6f8b
commit 659ef273c8

View File

@@ -582,23 +582,15 @@ def unified_flash_attention_v2(
    else:
        # unpaged (linear cache) path
        if use_mla:
-            # MLA: 镜像 paged 路径的处理方式
-            value_to_cache = None
+            # MLA cache 是 2D (total_slots, head_dim)
+            # 不能用 reshape_paged_cache(期望 4D),直接索引写入
            if attn_metadata.prefill_metadata:
                # MLA prefill cache 已在 forward_prefill 中写入,跳过
                pass
            else:
-                if kv_cache_dtype == 'int8':
-                    mlu_ops.quant_to_paged_cache(
-                        key, value_to_cache,
-                        key_cache, value_cache,
-                        key_cache_scale, value_cache_scale,
-                        updated_slot_mapping.flatten())
-                else:
-                    mlu_ops.reshape_paged_cache(
-                        key, value_to_cache,
-                        key_cache, value_cache,
-                        updated_slot_mapping.flatten())
+                # key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim)
+                # key_cache: (total_slots, head_dim)
+                key_cache[updated_slot_mapping.flatten()] = key.squeeze(1)
        else:
            # FIXME: After TMO-1496 is completed, remove this code.
            if key.stride() != value.stride():