add deepseekv3 and llama4

This commit is contained in:
Chranos
2026-02-11 14:26:59 +08:00
parent 128aed196c
commit 8ac7afcbd3
2 changed files with 13 additions and 4 deletions

View File

@@ -66,6 +66,7 @@ _TEXT_GENERATION_MODELS = {
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
+    "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),

View File

@@ -590,16 +590,24 @@ def unified_flash_attention_v2(
         pass
     else:
         if kv_cache_dtype == 'int8':
-            mlu_ops.quant_to_paged_cache(
+            mlu_ops.quant_to_linear_cache(
                 key, value_to_cache,
                 key_cache, value_cache,
                 key_cache_scale, value_cache_scale,
-                updated_slot_mapping.flatten())
+                attn_metadata.cu_seq_lens,
+                attn_metadata.max_seq_len,
+                True, None,
+                attn_metadata.batch_ids,
+                attn_metadata.slot_mapping_unpaged)
         else:
-            mlu_ops.reshape_paged_cache(
+            mlu_ops.reshape_linear_cache(
                 key, value_to_cache,
                 key_cache, value_cache,
-                updated_slot_mapping.flatten())
+                attn_metadata.cu_seq_lens,
+                attn_metadata.max_seq_len,
+                True, None,
+                attn_metadata.batch_ids,
+                attn_metadata.slot_mapping_unpaged)
 else:
     # FIXME: After TMO-1496 is completed, remove this code.
     if key.stride() != value.stride():