add deepseekv3
This commit is contained in:
@@ -582,23 +582,15 @@ def unified_flash_attention_v2(
|
||||
else:
|
||||
# unpaged (linear cache) path
|
||||
if use_mla:
|
||||
# MLA: 镜像 paged 路径的处理方式
|
||||
value_to_cache = None
|
||||
# MLA cache 是 2D (total_slots, head_dim),
|
||||
# 不能用 reshape_paged_cache(期望 4D),直接索引写入
|
||||
if attn_metadata.prefill_metadata:
|
||||
# MLA prefill cache 已在 forward_prefill 中写入,跳过
|
||||
pass
|
||||
else:
|
||||
if kv_cache_dtype == 'int8':
|
||||
mlu_ops.quant_to_paged_cache(
|
||||
key, value_to_cache,
|
||||
key_cache, value_cache,
|
||||
key_cache_scale, value_cache_scale,
|
||||
updated_slot_mapping.flatten())
|
||||
else:
|
||||
mlu_ops.reshape_paged_cache(
|
||||
key, value_to_cache,
|
||||
key_cache, value_cache,
|
||||
updated_slot_mapping.flatten())
|
||||
# key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim)
|
||||
# key_cache: (total_slots, head_dim)
|
||||
key_cache[updated_slot_mapping.flatten()] = key.squeeze(1)
|
||||
else:
|
||||
# FIXME: After TMO-1496 is completed, remove this code.
|
||||
if key.stride() != value.stride():
|
||||
|
||||
Reference in New Issue
Block a user