[feat] Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)

Author: VDV1985
Date: 2025-08-31 10:25:07 +03:00
Committed by: GitHub
Commit: ba861293cf (parent c112bcc461)
6 changed files with 136 additions and 30 deletions


@@ -53,7 +53,7 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

-if is_npu():
+if _is_npu:
     import torch_npu
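The guard now reads a module-level flag instead of calling the probe at the use site. A minimal sketch of what that flag presumably looks like (the import path and helper are assumptions; only the if-statement change is shown in this hunk):

from sglang.srt.utils import is_npu  # assumed helper; not shown in the diff

_is_npu = is_npu()  # probe the Ascend backend once, at import time

if _is_npu:
    import torch_npu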
@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
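For readers checking the numerics: the new forward_npu computes a Gemma-style RMSNorm, accumulating in float32, normalizing by the root-mean-square, and scaling by (1 + weight) before casting back. A minimal plain-PyTorch reference for spot-checking the NPU path (the helper name is ours, not from this commit):

import torch

def gemma_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Accumulate in float32, normalize by RMS, scale by (1 + weight),
    # then cast back to the input dtype, mirroring forward_npu above.
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    normed = x32 * torch.rsqrt(variance + eps)
    return (normed * (1.0 + weight.float())).to(x.dtype)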
-class Gemma3RMSNorm(nn.Module):
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)

+    def forward_cuda(self, x):
+        # Re-dispatch
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output

     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"