fix deepseek torchair precision (#3635)
### What this PR does / why we need it?
The precision of deepseek torchair was broken by #3465, due to the original patch of RMSNorm in torchair. This PR fixes the precision of deepseek torchair.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Signed-off-by: hust17yixuan <303660421@qq.com>
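For reviewers unfamiliar with the op being patched, here is a minimal plain-PyTorch sketch of the semantics the fixed path computes: RMS normalization with a scale weight plus an optional additive bias (the bias is only materialized for quantized checkpoints whose `quant_description` mentions `norm.bias`). This is an illustrative reference, not the `torch_npu` kernel itself.

```python
from typing import Optional

import torch


def rms_norm_reference(x: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Plain-PyTorch RMSNorm with an optional additive bias (illustrative)."""
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps) * weight
    if bias is not None:  # mirrors the x.add_(self.bias) step in this PR
        y = y + bias
    return y
```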
@@ -18,6 +18,32 @@
from typing import Optional, Tuple, Union

import torch
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm

from vllm_ascend.utils import version_check

_original_re_init = RMSNorm.__init__


def torchair_rmsnorm_init_(
    self,
    hidden_size: int,
    eps: float = 1e-6,
    var_hidden_size: Optional[int] = None,
    has_weight: bool = True,
    dtype: Optional[torch.dtype] = None,
) -> None:
    _original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
                      dtype)
    vllm_config = get_current_vllm_config()
    self.bias = None
    self.torch_npu_check = version_check()
    # quantization with anti_method m4 will generate a non-zero norm bias
    if self.torch_npu_check and vllm_config.quant_config is not None and \
            any("norm.bias" in name
                for name in vllm_config.quant_config.quant_description.keys()):
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)


def torchair_rmsnorm_forward_oot(
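A side note on the bias-detection condition above: the patched `__init__` only allocates a bias parameter when the model's quantization description mentions a `norm.bias` weight. A hedged sketch of that membership test, using a hypothetical `quant_description` dict whose exact keys and values are assumptions for illustration:

```python
# Hypothetical quant_description, shaped like a flat {weight_name: method}
# mapping; the concrete keys below are invented for this example.
quant_description = {
    "model.layers.0.input_layernorm.norm.bias": "anti_m4",
    "model.layers.0.mlp.gate_proj.weight": "w8a8",
}

# Same membership test as the patched __init__: a norm bias is needed
# whenever any described weight name contains "norm.bias".
needs_norm_bias = any("norm.bias" in name for name in quant_description)
print(needs_norm_bias)  # True for this example
```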
@@ -33,6 +59,7 @@ def torchair_rmsnorm_forward_oot(
    """

    import torch_npu
    torch_npu_check = version_check()

    from vllm_ascend.utils import is_310p
    if residual is not None:
@@ -45,7 +72,11 @@ def torchair_rmsnorm_forward_oot(
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        if torch_npu_check and self.bias is not None:
            x.add_(self.bias)
        return x, residual

    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
    if torch_npu_check and self.bias is not None:
        x.add_(self.bias)
    return x
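For context on the residual branch above: `torch_npu.npu_add_rms_norm` fuses the residual addition with the normalization (and also returns the updated residual stream, which is why the diff unpacks three values). A plain-PyTorch sketch of the equivalent unfused computation, illustrative only; the fused NPU kernel is the authoritative implementation:

```python
from typing import Tuple

import torch


def add_rms_norm_reference(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Unfused stand-in for npu_add_rms_norm: add residual, then RMSNorm."""
    residual = x + residual  # the new residual stream
    variance = residual.pow(2).mean(dim=-1, keepdim=True)
    y = residual * torch.rsqrt(variance + eps) * weight
    return y, residual  # corresponds to the (x, residual) the patch returns
```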
@@ -229,10 +229,12 @@ def torchair_ops_patch():
    AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func  # type: ignore[method-assign]
    AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward  # type: ignore[method-assign]

    AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
    AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]

    AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot  # type: ignore[method-assign]
    AscendVocabParallelEmbedding.forward = vocab_embedding_forward  # type: ignore[method-assign]


def super_kernel(prefix: str, option: str, enabled: bool = True):
    return _super_kernel(prefix, option) if enabled else nullcontext()
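The `torchair_ops_patch` hunk relies on method reassignment: the original `RMSNorm.__init__` is captured as `_original_re_init` before the class attribute is replaced, so the patched initializer can still delegate to it. A minimal, self-contained sketch of that save-then-wrap pattern, using a toy class rather than the vLLM one:

```python
class Norm:
    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


# Capture the original initializer BEFORE reassigning it, exactly as the
# patch does with `_original_re_init = RMSNorm.__init__`.
_original_init = Norm.__init__


def patched_init(self, hidden_size: int) -> None:
    _original_init(self, hidden_size)  # delegate to the saved original
    self.bias = None                   # then attach extra state


Norm.__init__ = patched_init  # type: ignore[method-assign]

n = Norm(8)
assert n.hidden_size == 8 and n.bias is None
```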