[feat] Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)
@@ -53,7 +53,7 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
-if is_npu():
+if _is_npu:
     import torch_npu
 
 
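The change above swaps a per-callsite function call, is_npu(), for a module-level flag, _is_npu, matching the _is_cuda/_is_hip convention visible in the hunk context, so the platform check runs once at import time. A minimal sketch of that gate, with a hypothetical stand-in for the helper (the real flag and utility live in sglang's source; only the shape of the pattern is the point here):

# Minimal sketch of the import-time platform gate (illustrative; the helper
# below is a hypothetical stand-in for sglang's own is_npu() utility).
import importlib.util
import logging

def is_npu() -> bool:
    # Assumption: presence of the torch_npu package implies an Ascend target.
    return importlib.util.find_spec("torch_npu") is not None

_is_npu = is_npu()  # computed once, alongside flags like _is_hip

logger = logging.getLogger(__name__)

if _is_npu:
    import torch_npu  # noqa: F401  # importing registers the NPU backend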
@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
 
 
-class Gemma3RMSNorm(nn.Module):
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
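Why the base-class change matters: CustomOp is what routes a call to forward_native, forward_cuda, or forward_npu, so switching Gemma3RMSNorm from plain nn.Module to CustomOp is what makes the new NPU kernel reachable at all. A minimal sketch of that dispatch idea, assuming it works roughly like sglang's CustomOp (the real class lives elsewhere in the tree; the platform checks below are illustrative stand-ins, not its actual logic):

import torch
from torch import nn

class CustomOp(nn.Module):
    """Simplified stand-in: choose one forward_* variant per platform, once."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        # Illustrative checks only; the real module decides via flags such as
        # _is_cuda / _is_hip / _is_npu computed at import time.
        if torch.cuda.is_available():
            return self.forward_cuda
        if getattr(torch, "npu", None) is not None:
            return self.forward_npu
        return self.forward_native

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    # Per-backend variants default to the portable implementation.
    def forward_cuda(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_npu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

Under this scheme, the forward → forward_native rename keeps the original pure-PyTorch math as the portable fallback, forward_cuda delegates to it explicitly, and forward_npu replaces the whole normalize-and-scale sequence with the single fused torch_npu.npu_gemma_rms_norm call.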
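The comment kept in forward_native ("Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)") is the subtle part of this norm, and it is also the property any fused NPU kernel has to preserve. A small self-contained check of what that ordering difference means in half precision (pure PyTorch, no torch_npu required; the (1.0 + w) offset follows the diff, and the tolerance-free prints are just for illustration):

import torch

def gemma_style(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma3 ordering: normalize and scale in float32, cast back last,
    # i.e. (x * w).to(float16); weight is stored as an offset from 1.0.
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (xf * (1.0 + weight.float())).type_as(x)

def llama_style(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Llama ordering: cast back to the input dtype first, then scale,
    # i.e. x.to(float16) * w -- a different rounding point in low precision.
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return xf.type_as(x) * (1.0 + weight.float()).type_as(x)

x = torch.randn(4, 64, dtype=torch.float16)
w = torch.randn(64, dtype=torch.float16) * 0.1

# The two orderings agree only up to float16 rounding; in float32 the gap vanishes.
print((gemma_style(x, w) - llama_style(x, w)).abs().max())
print((gemma_style(x.float(), w) - llama_style(x.float(), w)).abs().max())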