diff --git a/vllm-v0.6.2/vllm/model_executor/models/registry.py b/vllm-v0.6.2/vllm/model_executor/models/registry.py
index b805401..276fba1 100644
--- a/vllm-v0.6.2/vllm/model_executor/models/registry.py
+++ b/vllm-v0.6.2/vllm/model_executor/models/registry.py
@@ -66,6 +66,7 @@ _TEXT_GENERATION_MODELS = {
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
+    "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
index 82932ff..c7cc195 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/attention/backends/mlu_attn.py
@@ -590,16 +590,24 @@ def unified_flash_attention_v2(
             pass
         else:
             if kv_cache_dtype == 'int8':
-                mlu_ops.quant_to_paged_cache(
+                mlu_ops.quant_to_linear_cache(
                     key, value_to_cache, key_cache, value_cache,
                     key_cache_scale, value_cache_scale,
-                    updated_slot_mapping.flatten())
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
             else:
-                mlu_ops.reshape_paged_cache(
+                mlu_ops.reshape_linear_cache(
                     key, value_to_cache, key_cache, value_cache,
-                    updated_slot_mapping.flatten())
+                    attn_metadata.cu_seq_lens,
+                    attn_metadata.max_seq_len,
+                    True, None,
+                    attn_metadata.batch_ids,
+                    attn_metadata.slot_mapping_unpaged)
     else:
         # FIXME: After TMO-1496 is completed, remove this code.
         if key.stride() != value.stride():
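
Note on the registry hunk: it aliases the Hugging Face architecture string
Llama4ForConditionalGeneration to vLLM's existing Llama4ForCausalLM
implementation, so checkpoints whose config.json declares the conditional
generation class name still load the text decoder. A minimal sketch of how
the registry tuple resolves at load time; the real code in registry.py adds
lazy imports and plugin hooks, and resolve_model_cls below is illustrative,
not vLLM's actual function:

    import importlib

    _TEXT_GENERATION_MODELS = {
        # architecture string from config.json ->
        # (module under vllm.model_executor.models, class to instantiate)
        "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
        # Alias added by this patch: both names map to the same class.
        "Llama4ForConditionalGeneration": ("llama4", "Llama4ForCausalLM"),
    }

    def resolve_model_cls(architecture: str):
        # Look up the (module, class) pair, import lazily, fetch the class.
        module_name, cls_name = _TEXT_GENERATION_MODELS[architecture]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, cls_name)

    # e.g. resolve_model_cls("Llama4ForConditionalGeneration") returns
    # vllm.model_executor.models.llama4.Llama4ForCausalLM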
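
Note on the mlu_attn hunk: it moves both the int8 and the non-quantized
KV-cache writes from the paged kernels, which consumed a flattened per-token
slot mapping, to linear-cache variants driven by unpadded batch metadata.
mlu_ops is Cambricon's closed kernel library, so only the argument names at
the call site are known; the sketch below shows one plausible way such
metadata could be assembled, assuming the flash-attn-style cu_seqlens
prefix-sum convention, and it does not model the `True, None` positional
arguments, whose meaning is not recoverable from the diff:

    import torch

    # Hypothetical construction of the attn_metadata fields passed above.
    seq_lens = [5, 3, 7]              # tokens per sequence in the batch
    lens = torch.tensor(seq_lens, dtype=torch.int32)

    # cu_seq_lens: exclusive prefix sums over sequence lengths, [0, 5, 8, 15].
    # A kernel can slice the packed tokens of sequence i as
    # cu_seq_lens[i]:cu_seq_lens[i + 1].
    cu_seq_lens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seq_lens[1:] = torch.cumsum(lens, dim=0)

    # max_seq_len bounds the kernel's per-sequence loop / launch grid.
    max_seq_len = int(lens.max())     # 7

    # batch_ids: owning sequence of each packed token, [0]*5 + [1]*3 + [2]*7.
    batch_ids = torch.repeat_interleave(
        torch.arange(len(seq_lens), dtype=torch.int32), lens.long())

    # slot_mapping_unpaged: per-token destination offsets in a contiguous
    # ("linear") cache; with a back-to-back layout this is just the packed
    # token index, unlike the block-table-derived slots that the removed
    # quant_to_paged_cache/reshape_paged_cache calls consumed.
    slot_mapping_unpaged = torch.arange(int(lens.sum()), dtype=torch.int32)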