diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py index 495b2a4..c147c1d 100644 --- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py +++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py @@ -176,14 +176,14 @@ def vllm__llama4__Llama4Attention__forward( ================== ''' - # QK norm (教训 #2: use contiguous + reshape) + # QK norm (MLU fused_rms_norm requires matching dtypes, skip .float()) if self.qk_norm is not None: q = q.contiguous().reshape(-1, self.head_dim) - q = (self.qk_norm(q.float()) - .contiguous().reshape(-1, self.q_size).to(q.dtype)) + q = (self.qk_norm(q) + .contiguous().reshape(-1, self.q_size)) k = k.contiguous().reshape(-1, self.head_dim) - k = (self.qk_norm(k.float()) - .contiguous().reshape(-1, self.kv_size).to(k.dtype)) + k = (self.qk_norm(k) + .contiguous().reshape(-1, self.kv_size)) # Temperature tuning for NoPE layers if self.attn_temperature_tuning and self.nope: