[Bug fix] Fix Gemma 2 and Gemma 3 multimodal with batch size > 1 on NPU (#9871)

Co-authored-by: Maksim <makcum888e@mail.ru>
This commit is contained in:
ssshinigami
2025-09-08 11:19:40 +03:00
committed by GitHub
parent ee21817c6b
commit 5dd8c6444b
3 changed files with 11 additions and 9 deletions

View File

@@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
if residual is not None:
x = x + residual
residual = x
x = x.float()
variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
x = x * (1.0 + self.weight.float())
x = x.to(orig_dtype)
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
return x if residual is None else (x, residual)