enable rmsnorm on XPU (#10248)
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
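Note: the comprehension above just filters the optional profiler activities by a platform check. A minimal equivalent sketch (not part of the commit, shown only to unpack the comprehension):

    profile_activities = [torch.profiler.ProfilerActivity.CPU]
    if is_cuda_alike():  # CUDA or HIP build
        profile_activities.append(torch.profiler.ProfilerActivity.CUDA)
    if is_xpu():  # Intel XPU build
        profile_activities.append(torch.profiler.ProfilerActivity.XPU)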
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
             )
 
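Note: both profiler call sites in latency_test_run_once previously hard-coded CPU and CUDA activities; they now share the platform-aware profile_activities list, so XPU runs get traced as well. A short usage sketch (illustrative only; it assumes profile_activities as defined above, and the matmul is a stand-in workload, not the benchmark's decode loop):

    import torch

    with torch.profiler.profile(
        activities=profile_activities,
        with_stack=True,
        record_shapes=True,
    ) as prof:
        a = torch.randn(512, 512)
        (a @ a).sum()  # produce a few trace events
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))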
@@ -52,8 +52,13 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
-
-
+elif _is_xpu:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
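Note: RMSNorm.forward_xpu mirrors the CUDA path, falling back to forward_native whenever variance_size_override is set. For reference, the math the rmsnorm kernel computes, sketched from the standard RMSNorm definition (not code from this file):

    import torch

    def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # y = x / sqrt(mean(x^2) + eps) * weight, accumulated in float32
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x_f * weight.float()).to(x.dtype)

fused_add_rmsnorm folds the preceding residual add into the same kernel: it updates x and residual in place, which is why the XPU path returns the pair directly.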
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
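Note: Gemma's variant stores the learned scale zero-centered, so the effective multiplier is (1 + weight) rather than weight. A reference sketch of that difference (based on the published Gemma definition, not on code in this diff):

    import torch

    def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x_f * (1.0 + weight.float())).to(x.dtype)

Hoisting the body into _forward_impl lets every backend that ships the gemma_rmsnorm kernels (CUDA before, XPU with this commit) share one implementation.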
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
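Note: forward_cuda is now a one-line delegation to _forward_impl, so the CUDA and XPU paths cannot drift apart. Hypothetical caller-side view (the constructor arguments here are assumed for illustration, not taken from this diff):

    norm = GemmaRMSNorm(hidden_size=4096, eps=1e-6)
    y = norm(x)                      # plain normalization
    y, residual = norm(x, residual)  # fused residual add + normalization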
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
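Note: the new forward_xpu entry points are picked up by CustomOp's per-backend dispatch, the same mechanism GemmaRMSNorm's __init__ uses above when it pins _forward_method to forward_native on HIP. A simplified sketch of that dispatch (assumed shape, reusing the module-level _is_cuda/_is_xpu flags seen in the diff; sglang's actual CustomOp differs in detail):

    import torch

    class CustomOp(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # choose the backend-specific forward once, at construction time
            if _is_cuda:
                self._forward_method = self.forward_cuda
            elif _is_xpu:
                self._forward_method = self.forward_xpu  # path enabled by this commit
            else:
                self._forward_method = self.forward_native

        def forward(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)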