forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3
This commit is contained in:
@@ -580,34 +580,49 @@ def unified_flash_attention_v2(
|
|||||||
value_cache,
|
value_cache,
|
||||||
updated_slot_mapping.flatten())
|
updated_slot_mapping.flatten())
|
||||||
else:
|
else:
|
||||||
# FIXME: After TMO-1496 is completed, remove this code.
|
# unpaged (linear cache) path
|
||||||
if key.stride() != value.stride():
|
if use_mla:
|
||||||
key = key.contiguous()
|
# MLA: 镜像 paged 路径的处理方式
|
||||||
value = value.contiguous()
|
value_to_cache = None
|
||||||
if kv_cache_dtype == 'int8':
|
if attn_metadata.prefill_metadata:
|
||||||
mlu_ops.quant_to_linear_cache(key,
|
# MLA prefill cache 已在 forward_prefill 中写入,跳过
|
||||||
value,
|
pass
|
||||||
key_cache,
|
else:
|
||||||
value_cache,
|
if kv_cache_dtype == 'int8':
|
||||||
key_cache_scale,
|
mlu_ops.quant_to_paged_cache(
|
||||||
value_cache_scale,
|
key, value_to_cache,
|
||||||
attn_metadata.cu_seq_lens,
|
key_cache, value_cache,
|
||||||
attn_metadata.max_seq_len,
|
key_cache_scale, value_cache_scale,
|
||||||
True, # packed
|
updated_slot_mapping.flatten())
|
||||||
None, # context_seq_offset
|
else:
|
||||||
attn_metadata.batch_ids,
|
mlu_ops.reshape_paged_cache(
|
||||||
attn_metadata.slot_mapping_unpaged)
|
key, value_to_cache,
|
||||||
|
key_cache, value_cache,
|
||||||
|
updated_slot_mapping.flatten())
|
||||||
else:
|
else:
|
||||||
mlu_ops.reshape_linear_cache(key,
|
# FIXME: After TMO-1496 is completed, remove this code.
|
||||||
value,
|
if key.stride() != value.stride():
|
||||||
key_cache,
|
key = key.contiguous()
|
||||||
value_cache,
|
value = value.contiguous()
|
||||||
attn_metadata.cu_seq_lens,
|
if kv_cache_dtype == 'int8':
|
||||||
attn_metadata.max_seq_len,
|
mlu_ops.quant_to_linear_cache(
|
||||||
True, # packed
|
key, value,
|
||||||
None, # context_seq_offset
|
key_cache, value_cache,
|
||||||
attn_metadata.batch_ids,
|
key_cache_scale, value_cache_scale,
|
||||||
attn_metadata.slot_mapping_unpaged)
|
attn_metadata.cu_seq_lens,
|
||||||
|
attn_metadata.max_seq_len,
|
||||||
|
True, None,
|
||||||
|
attn_metadata.batch_ids,
|
||||||
|
attn_metadata.slot_mapping_unpaged)
|
||||||
|
else:
|
||||||
|
mlu_ops.reshape_linear_cache(
|
||||||
|
key, value,
|
||||||
|
key_cache, value_cache,
|
||||||
|
attn_metadata.cu_seq_lens,
|
||||||
|
attn_metadata.max_seq_len,
|
||||||
|
True, None,
|
||||||
|
attn_metadata.batch_ids,
|
||||||
|
attn_metadata.slot_mapping_unpaged)
|
||||||
if use_mla and attn_metadata.prefill_metadata:
|
if use_mla and attn_metadata.prefill_metadata:
|
||||||
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
|
output = torch.empty(query.shape[0], query.shape[1], v_head_size, dtype=query.dtype, device="mlu")
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user