diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index 39c1c2681..5b502a153 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -1,11 +1,12 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()
 
 
 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError
 
+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)
 
@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index a9e3436a6..d84d3eda5 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -48,6 +48,9 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 2277a70af..5d8106f17 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -52,6 +52,9 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index bd145a4b0..f931098be 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -8,7 +8,14 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
@@ -19,6 +26,9 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
+if is_npu():
+    import torch_npu
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -152,6 +162,30 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-NPU implementation of forward()."""
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half" if self.is_neox_style else "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
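
A minimal parity sketch (not part of the diff) for the new element-wise NPU paths: it compares `SiluAndMul.forward_npu` and `RMSNorm.forward_npu` against the existing `forward_native` fallbacks. It assumes an Ascend environment with `torch_npu` installed and the modified sglang sources importable; the shapes, dtypes, and tolerances are illustrative assumptions, not values taken from the PR.

```python
# Parity sketch: new NPU kernels vs. the pure-PyTorch fallbacks.
# Assumes torch_npu is installed and an "npu" device is visible.
import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm

device = "npu"
hidden = 1024

# SiluAndMul: forward_npu wraps torch_npu.npu_swiglu, which should match
# silu(x[..., :d]) * x[..., d:] computed by forward_native.
x = torch.randn(8, 2 * hidden, dtype=torch.float16, device=device)
act = SiluAndMul()
torch.testing.assert_close(
    act.forward_npu(x), act.forward_native(x), rtol=2e-2, atol=2e-2
)

# RMSNorm: the residual branch uses npu_add_rms_norm, the plain branch npu_rms_norm.
norm = RMSNorm(hidden).to(device=device, dtype=torch.float16)
h = torch.randn(8, hidden, dtype=torch.float16, device=device)
res = torch.randn(8, hidden, dtype=torch.float16, device=device)

out_npu, res_npu = norm.forward_npu(h.clone(), res.clone())
out_ref, res_ref = norm.forward_native(h.clone(), res.clone())
torch.testing.assert_close(out_npu, out_ref, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(res_npu, res_ref, rtol=2e-2, atol=2e-2)
```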
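A similar sketch (again not part of the diff) for the rotary path drives `RotaryEmbedding.forward_npu`, which maps `is_neox_style` to `rotary_mode="half"` and passes `mrope_section=[0, 0, 0]` for the non-multimodal case, against `forward_native`. The constructor arguments follow the existing sglang `RotaryEmbedding` class; head/token sizes, dtype, and tolerances are assumptions made for illustration.

```python
# Parity sketch: RotaryEmbedding.forward_npu (npu_mrope) vs. forward_native.
import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

from sglang.srt.layers.rotary_embedding import RotaryEmbedding

head_size, num_heads, num_tokens = 128, 8, 16
rope = RotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,  # maps to rotary_mode="half" in forward_npu
    dtype=torch.float16,
).to("npu")

positions = torch.arange(num_tokens, device="npu")
q = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="npu")
k = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="npu")

q_npu, k_npu = rope.forward_npu(positions, q.clone(), k.clone())
q_ref, k_ref = rope.forward_native(positions, q.clone(), k.clone())
torch.testing.assert_close(q_npu, q_ref, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(k_npu, k_ref, rtol=2e-2, atol=2e-2)
```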