add deepseekv3

This commit is contained in:
Chranos
2026-02-11 13:12:46 +08:00
parent 094541296e
commit 98003e6f8b

View File

@@ -580,34 +580,49 @@ def unified_flash_attention_v2(
value_cache, value_cache,
updated_slot_mapping.flatten()) updated_slot_mapping.flatten())
else: else:
# FIXME: After TMO-1496 is completed, remove this code. # unpaged (linear cache) path
if key.stride() != value.stride(): if use_mla:
key = key.contiguous() # MLA: mirror the handling of the paged-cache path
value = value.contiguous() value_to_cache = None
if kv_cache_dtype == 'int8': if attn_metadata.prefill_metadata:
mlu_ops.quant_to_linear_cache(key, # MLA prefill cache already written in forward_prefill; skip
value, pass
key_cache, else:
value_cache, if kv_cache_dtype == 'int8':
key_cache_scale, mlu_ops.quant_to_paged_cache(
value_cache_scale, key, value_to_cache,
attn_metadata.cu_seq_lens, key_cache, value_cache,
attn_metadata.max_seq_len, key_cache_scale, value_cache_scale,
True, # packed updated_slot_mapping.flatten())
None, # context_seq_offset else:
attn_metadata.batch_ids, mlu_ops.reshape_paged_cache(
attn_metadata.slot_mapping_unpaged) key, value_to_cache,
key_cache, value_cache,
updated_slot_mapping.flatten())
else: else:
mlu_ops.reshape_linear_cache(key, # FIXME: After TMO-1496 is completed, remove this code.
value, if key.stride() != value.stride():
key_cache, key = key.contiguous()
value_cache, value = value.contiguous()
attn_metadata.cu_seq_lens, if kv_cache_dtype == 'int8':
attn_metadata.max_seq_len, mlu_ops.quant_to_linear_cache(
True, # packed key, value,
None, # context_seq_offset key_cache, value_cache,
attn_metadata.batch_ids, key_cache_scale, value_cache_scale,
attn_metadata.slot_mapping_unpaged) attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
else:
mlu_ops.reshape_linear_cache(
key, value,
key_cache, value_cache,
attn_metadata.cu_seq_lens,
attn_metadata.max_seq_len,
True, None,
attn_metadata.batch_ids,
attn_metadata.slot_mapping_unpaged)
if use_mla and attn_metadata.prefill_metadata: if use_mla and attn_metadata.prefill_metadata:
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu") output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
else: else: