forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
This commit is contained in:
@@ -143,11 +143,14 @@ class RMSNorm(CustomOp):
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
x = x.view(-1, self.weight.data.shape[0])
|
||||
weight = self.weight.data
|
||||
if weight.dtype != x.dtype:
|
||||
weight = weight.to(x.dtype)
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, self.weight.data.shape[0])
|
||||
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True)
|
||||
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True)
|
||||
else:
|
||||
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False)
|
||||
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"hidden_size={self.weight.data.size(0)}"
|
||||
|
||||
Reference in New Issue
Block a user