[Bug fix] Fix Gemma 2 and fix Gemma 3 multimodal with batch size > 1 on NPU (#9871)
Co-authored-by: Maksim <makcum888e@mail.ru>
This commit is contained in:
@@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp):
|
||||
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = x.dtype
         if residual is not None:
             x = x + residual
             residual = x
 
-        x = x.float()
-        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
-        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user