npu fused op (#7386)
Co-authored-by: Li Junwen <lijunwen13@hisilicon.com>
@@ -1,11 +1,12 @@
 from torch import nn

-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()


 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError

+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)

@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
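The hunks above add the NPU branch to CustomOp's backend dispatch: a default `forward_npu` stub alongside `forward_cuda`, plus an `elif _is_npu` case that routes calls to it. A minimal sketch of the pattern, assuming (the diff does not show this part of the class) that the backend is resolved once at construction time; the `_DispatchOp` name and the hard-coded flag are stand-ins:

```python
from torch import nn

_is_npu = True  # stand-in for sglang.srt.utils.is_npu(); hard-coded for this sketch


class _DispatchOp(nn.Module):
    """Hypothetical miniature of the CustomOp dispatch shown in the diff."""

    def __init__(self):
        super().__init__()
        # Pick the backend implementation once; forward() then calls it directly.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_npu(self, *args, **kwargs):
        # Subclasses with an NPU kernel override this, as SiluAndMul, RMSNorm,
        # and RotaryEmbedding do in the hunks below.
        raise NotImplementedError

    def dispatch_forward(self):
        if _is_npu:
            return self.forward_npu
        return self.forward_native
```

Ops that do not override `forward_npu` fail loudly rather than silently falling back.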
@@ -48,6 +48,9 @@ if _is_cuda:

 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)

+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
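Here `SiluAndMul.forward_npu` hands the whole gate/up projection output to the fused `torch_npu.npu_swiglu` kernel. For reference, the unfused computation that `forward_native` performs, and that the fused call is expected to reproduce, is roughly:

```python
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into gate/up halves and return silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
```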
@@ -52,6 +52,9 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
             out = rmsnorm(x, self.weight.data, self.variance_epsilon)
             return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
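For RMSNorm, `npu_add_rms_norm` fuses the residual addition with the normalization and returns both the normalized output and the updated residual (hence the ignored middle return value), while `npu_rms_norm` covers the no-residual case (hence the `[0]`). An unfused sketch of the same math, omitting the float32 upcast a careful implementation would typically add:

```python
from typing import Optional, Tuple, Union

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Optionally add the residual first, then scale by the reciprocal RMS over
    # the last dimension and by the learned weight.
    if residual is not None:
        x = x + residual
        out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
        return out, x
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
```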
@@ -8,7 +8,14 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
@@ -19,6 +26,9 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

+if is_npu():
+    import torch_npu
+

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -152,6 +162,36 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
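`RotaryEmbedding.forward_npu` defers to `forward_native` when torch.compile is enabled and otherwise calls `torch_npu.npu_mrope`, which rotates query and key in one kernel from the precomputed `cos_sin_cache`. `rotary_mode="half"` corresponds to the NeoX-style layout (rotate the two halves of each head) and `"interleave"` to the GPT-J-style even/odd pairing; `mrope_section=[0, 0, 0]` appears to select plain RoPE rather than multimodal sectioning. A sketch of the half-layout rotation the kernel applies per position:

```python
import torch


def rotate_half_reference(x: torch.Tensor) -> torch.Tensor:
    # NeoX-style ("half") layout: swap the two halves of the head dim, negating one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Per-position rotation q * cos + rotate_half(q) * sin; npu_mrope performs this
    # for query and key together using the cached cos/sin values.
    return q * cos + rotate_half_reference(q) * sin
```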