enable rmsnorm on XPU (#10248)
This commit is contained in:
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
configure_logger,
|
configure_logger,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
is_cuda_alike,
|
||||||
|
is_xpu,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
require_mlp_sync,
|
require_mlp_sync,
|
||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||||
|
|
||||||
|
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
|
||||||
|
profiler_activity
|
||||||
|
for available, profiler_activity in [
|
||||||
|
(is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
|
||||||
|
(is_xpu(), torch.profiler.ProfilerActivity.XPU),
|
||||||
|
]
|
||||||
|
if available
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BenchArgs:
|
class BenchArgs:
|
||||||
@@ -424,10 +435,7 @@ def latency_test_run_once(
|
|||||||
profiler = None
|
profiler = None
|
||||||
if profile:
|
if profile:
|
||||||
profiler = torch.profiler.profile(
|
profiler = torch.profiler.profile(
|
||||||
activities=[
|
activities=profile_activities,
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
],
|
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
record_shapes=profile_record_shapes,
|
record_shapes=profile_record_shapes,
|
||||||
)
|
)
|
||||||
@@ -460,10 +468,7 @@ def latency_test_run_once(
|
|||||||
if profile and i == output_len / 2:
|
if profile and i == output_len / 2:
|
||||||
profiler = None
|
profiler = None
|
||||||
profiler = torch.profiler.profile(
|
profiler = torch.profiler.profile(
|
||||||
activities=[
|
activities=profile_activities,
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
],
|
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
record_shapes=profile_record_shapes,
|
record_shapes=profile_record_shapes,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,8 +52,13 @@ if _is_cuda:
|
|||||||
gemma_rmsnorm,
|
gemma_rmsnorm,
|
||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
elif _is_xpu:
|
||||||
|
from sgl_kernel import (
|
||||||
|
fused_add_rmsnorm,
|
||||||
|
gemma_fused_add_rmsnorm,
|
||||||
|
gemma_rmsnorm,
|
||||||
|
rmsnorm,
|
||||||
|
)
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import rmsnorm2d_fwd as rms_norm
|
from aiter import rmsnorm2d_fwd as rms_norm
|
||||||
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
||||||
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
|
|||||||
else:
|
else:
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
|
def forward_xpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
if self.variance_size_override is not None:
|
||||||
|
return self.forward_native(x, residual)
|
||||||
|
if residual is not None:
|
||||||
|
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
|
||||||
|
return x, residual
|
||||||
|
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||||
|
return out
|
||||||
|
|
||||||
def forward_with_allreduce_fusion(
|
def forward_with_allreduce_fusion(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
if _is_hip:
|
if _is_hip:
|
||||||
self._forward_method = self.forward_native
|
self._forward_method = self.forward_native
|
||||||
|
|
||||||
|
def _forward_impl(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
if residual is not None:
|
||||||
|
gemma_fused_add_rmsnorm(
|
||||||
|
x, residual, self.weight.data, self.variance_epsilon
|
||||||
|
)
|
||||||
|
return x, residual
|
||||||
|
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||||
|
return out
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor] = None,
|
residual: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
if residual is not None:
|
return self._forward_impl(x, residual)
|
||||||
gemma_fused_add_rmsnorm(
|
|
||||||
x, residual, self.weight.data, self.variance_epsilon
|
|
||||||
)
|
|
||||||
return x, residual
|
|
||||||
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def forward_npu(
|
def forward_npu(
|
||||||
self,
|
self,
|
||||||
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
||||||
return x if residual is None else (x, residual)
|
return x if residual is None else (x, residual)
|
||||||
|
|
||||||
|
def forward_xpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
return self._forward_impl(x, residual)
|
||||||
|
|
||||||
|
|
||||||
class Gemma3RMSNorm(CustomOp):
|
class Gemma3RMSNorm(CustomOp):
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
|||||||
Reference in New Issue
Block a user