forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
@@ -66,6 +66,7 @@ _TEXT_GENERATION_MODELS = {
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
+    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
     "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
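For context, `_TEXT_GENERATION_MODELS` maps the architecture string from a checkpoint's HF config to a `(module, class)` pair that vLLM imports lazily. The sketch below shows how such a lookup can resolve the new Llama4 entries; `resolve_model_class` and the `vllm.model_executor.models` package path are illustrative assumptions, not this fork's actual code.

    # Minimal sketch of resolving a registry like _TEXT_GENERATION_MODELS.
    # The dict entries come from the hunk above; the resolver helper and
    # module prefix are assumptions for illustration.
    import importlib

    _TEXT_GENERATION_MODELS = {
        "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
        "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
        "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    }

    def resolve_model_class(architecture: str):
        """Map an HF `architectures` entry to a model class, importing the
        implementing module lazily (hypothetical helper)."""
        module_name, class_name = _TEXT_GENERATION_MODELS[architecture]
        # Model implementations live under vllm.model_executor.models.*;
        # the exact package path here is an assumption.
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, class_name)

    # Example: both Llama4 architecture names resolve to the same class.
    # cls = resolve_model_class("Llama4ForConditionalGeneration")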
@@ -590,16 +590,24 @@ def unified_flash_attention_v2(
             pass
         else:
             if kv_cache_dtype == 'int8':
-                mlu_ops.quant_to_paged_cache(
+                mlu_ops.quant_to_linear_cache(
                     key, value_to_cache,
                     key_cache, value_cache,
                     key_cache_scale, value_cache_scale,
-                    updated_slot_mapping.flatten())
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
             else:
-                mlu_ops.reshape_paged_cache(
+                mlu_ops.reshape_linear_cache(
                     key, value_to_cache,
                     key_cache, value_cache,
-                    updated_slot_mapping.flatten())
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
     else:
         # FIXME: After TMO-1496 is completed, remove this code.
         if key.stride() != value.stride():
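This hunk swaps the paged-cache kernels (`quant_to_paged_cache` / `reshape_paged_cache`, addressed by a flattened slot mapping) for linear-cache variants driven by cumulative sequence lengths, batch ids, and an unpaged slot mapping. The `mlu_ops` kernels are opaque Cambricon MLU primitives, so the pure-PyTorch sketch below only mirrors the data movement the two layouts imply; every shape and the meaning of `slot_mapping_unpaged` are assumptions inferred from the call sites.

    # Illustrative contrast of the two KV-cache write patterns; not the
    # MLU kernels themselves.
    import torch

    def write_paged_cache(key, value, key_cache, value_cache, slot_mapping):
        """Paged layout: each token is scattered to an arbitrary cache slot
        (conceptually what the removed reshape_paged_cache call did)."""
        # key/value: [num_tokens, num_heads, head_size]
        # key_cache/value_cache: [num_slots, num_heads, head_size]
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value

    def write_linear_cache(key, value, key_cache, value_cache,
                           cu_seq_lens, batch_ids, slot_mapping_unpaged):
        """Linear (unpaged) layout: each sequence's tokens land in one
        contiguous region, addressed per sequence rather than per block.
        Argument meanings are inferred from the new call site."""
        for i, b in enumerate(batch_ids.tolist()):
            start = cu_seq_lens[i].item()
            end = cu_seq_lens[i + 1].item()
            # Assumed: slot_mapping_unpaged gives the region start per sequence.
            base = slot_mapping_unpaged[b].item()
            n = end - start
            key_cache[base:base + n] = key[start:end]
            value_cache[base:base + n] = value[start:end]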