From 094c116f7dc4a5a4d845ce812406e9514a275266 Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Wed, 18 Jun 2025 13:11:50 +0800 Subject: [PATCH] Update python API of activation, topk, norm and rope and remove vllm dependency (#6614) Co-authored-by: Wu, Chunyuan Co-authored-by: jianan-gu Co-authored-by: sdp --- docker/Dockerfile.xeon | 1 + python/sglang/srt/custom_op.py | 6 +- python/sglang/srt/layers/activation.py | 23 ++++- python/sglang/srt/layers/layernorm.py | 30 +++++- .../layers/moe/fused_moe_triton/fused_moe.py | 6 ++ .../srt/layers/moe/fused_moe_triton/layer.py | 6 +- python/sglang/srt/layers/moe/topk.py | 95 ++++++++++++++++++- .../compressed_tensors_moe.py | 7 +- python/sglang/srt/layers/quantization/fp8.py | 6 +- .../sglang/srt/layers/quantization/utils.py | 6 +- python/sglang/srt/layers/rotary_embedding.py | 43 ++++++++- .../sglang/srt/model_executor/model_runner.py | 3 +- python/sglang/srt/models/deepseek_v2.py | 11 ++- python/sglang/srt/utils.py | 15 ++- test/srt/cpu/test_activation.py | 2 +- test/srt/cpu/test_gemm.py | 10 +- test/srt/cpu/test_moe.py | 8 +- test/srt/cpu/test_norm.py | 8 +- test/srt/cpu/test_qkv_proj_with_rope.py | 24 ++--- test/srt/cpu/test_rope.py | 4 +- test/srt/cpu/test_shared_expert.py | 6 +- test/srt/cpu/test_topk.py | 4 +- test/srt/run_suite.py | 2 + 23 files changed, 270 insertions(+), 56 deletions(-) diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon index d622aa736..2f03b6485 100644 --- a/docker/Dockerfile.xeon +++ b/docker/Dockerfile.xeon @@ -39,6 +39,7 @@ RUN git clone https://github.com/sgl-project/sglang.git && \ cp pyproject_cpu.toml pyproject.toml && \ pip install -v . +ENV SGLANG_USE_CPU_ENGINE=1 ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2 WORKDIR /sgl-workspace/sglang diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index ba34dc8e4..39c1c2681 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -1,9 +1,11 @@ from torch import nn -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip _is_cuda = is_cuda() _is_hip = is_hip() +_is_cpu = is_cpu() +_is_cpu_amx_available = cpu_has_amx_support() class CustomOp(nn.Module): @@ -75,5 +77,7 @@ class CustomOp(nn.Module): return self.forward_cuda elif _is_hip: return self.forward_hip + elif _is_cpu and _is_cpu_amx_available: + return self.forward_cpu else: return self.forward_native diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index b018743bc..a9e3436a6 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -29,11 +29,19 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs +from sglang.srt.utils import ( + cpu_has_amx_support, + is_cpu, + is_cuda, + is_npu, + set_weight_attrs, +) from sglang.utils import resolve_obj_by_qualname _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul @@ -53,6 +61,15 @@ class SiluAndMul(CustomOp): silu_and_mul(x, out) return out + def forward_cpu(self, x: torch.Tensor) -> torch.Tensor: + if _is_cpu_amx_available: + d = x.shape[-1] // 2 + output_shape = 
x.shape[:-1] + (d,) + out = torch.ops.sgl_kernel.silu_and_mul_cpu(x) + return out + else: + return self.forward_native(x) + class GeluAndMul(CustomOp): def __init__(self, approximate="tanh"): @@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return nn.Identity() -if not _is_cuda and not _is_npu: +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): logger.info( - "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 3ccff5a72..2277a70af 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -20,12 +20,21 @@ import torch import torch.nn as nn from sglang.srt.custom_op import CustomOp -from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_cuda, + is_hip, + is_npu, +) _is_cuda = is_cuda() _is_hip = is_hip() _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import ( @@ -122,6 +131,23 @@ class RMSNorm(CustomOp): else: return x, residual + def forward_cpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_cpu_amx_available: + if residual is not None: + torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( + x, residual, self.weight.data, self.variance_epsilon + ) + return x, residual + return torch.ops.sgl_kernel.rmsnorm_cpu( + x, self.weight.data, self.variance_epsilon + ) + else: + return self.forward_native(x, residual) + class GemmaRMSNorm(CustomOp): def __init__( @@ -188,7 +214,7 @@ class Gemma3RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.eps}" -if not (_is_cuda or _is_hip or _is_npu): +if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)): logger.info( "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries." 
) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index e8d3b58ce..f7690cb86 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import ( sglang_per_token_group_quant_int8, ) from sglang.srt.utils import ( + cpu_has_amx_support, direct_register_custom_op, get_bool_env_var, get_device_name, + is_cpu, is_cuda, is_hip, log_info_on_rank0, @@ -36,9 +38,13 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul +elif _is_cpu and _is_cpu_amx_available: + pass else: from vllm import _custom_ops as vllm_ops from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 96b89340c..7cf8de28a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -241,7 +241,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): num_fused_shared_experts: int = 0, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: return moe_forward_native( layer, @@ -260,7 +264,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def forward_tpu(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("The TPU backend currently does not support MoE.") - forward_native = forward_cuda + forward_native = forward_cpu class FusedMoE(torch.nn.Module): diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 0c3d92b66..24d046325 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import ( topk_ids_logical_to_physical, ) from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip +from sglang.srt.utils import ( + cpu_has_amx_support, + get_compiler_backend, + is_cpu, + is_cuda, + is_hip, +) _is_cuda = is_cuda() _is_hip = is_hip() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import moe_fused_gate @@ -40,7 +48,7 @@ if _is_cuda or _is_hip: from sgl_kernel import topk_softmax -def fused_topk_native( +def fused_topk_torch_native( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, @@ -61,6 +69,20 @@ def fused_topk_native( return topk_weights, topk_ids +def fused_topk_cpu( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + return torch.ops.sgl_kernel.topk_softmax_cpu( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + ) + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -115,7 +137,7 @@ def _fused_topk_postprocess( # This is used by the Deepseek V2/V3/R1 series models @torch.compile(dynamic=True, backend=get_compiler_backend()) -def grouped_topk( +def grouped_topk_gpu( hidden_states: 
torch.Tensor, gating_output: torch.Tensor, topk: int, @@ -171,6 +193,32 @@ def grouped_topk( return topk_weights, topk_ids +def grouped_topk_cpu( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, +): + assert expert_location_dispatch_info is None + return torch.ops.sgl_kernel.grouped_topk_cpu( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + num_fused_shared_experts, + routed_scaling_factor, + num_token_non_padded, + ) + + def biased_grouped_topk_impl( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess( return topk_ids -def biased_grouped_topk( +def biased_grouped_topk_gpu( hidden_states: torch.Tensor, gating_output: torch.Tensor, correction_bias: torch.Tensor, @@ -322,6 +370,45 @@ def biased_grouped_topk( ) +def biased_grouped_topk_cpu( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + compiled: bool = True, + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, +): + assert expert_location_dispatch_info is None + return torch.ops.sgl_kernel.biased_grouped_topk_cpu( + hidden_states, + gating_output, + correction_bias, + topk, + renormalize, + num_expert_group, + topk_group, + num_fused_shared_experts, + routed_scaling_factor, + num_token_non_padded, + ) + + +if _is_cpu and _is_cpu_amx_available: + biased_grouped_topk = biased_grouped_topk_cpu + grouped_topk = grouped_topk_cpu + fused_topk_native = fused_topk_cpu +else: + biased_grouped_topk = biased_grouped_topk_gpu + grouped_topk = grouped_topk_gpu + fused_topk_native = fused_topk_torch_native + + def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ee08c1f55..b471184d2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -14,15 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( all_close_1d, + cpu_has_amx_support, per_tensor_dequantize, replace_parameter, ) -from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs +from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() -if not _is_cuda and not _is_npu: +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): from vllm import _custom_ops as vllm_ops from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bbdd64ba8..36807aeda 100644 
--- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -64,7 +64,9 @@ from sglang.srt.layers.quantization.utils import ( ) from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.utils import ( + cpu_has_amx_support, get_bool_env_var, + is_cpu, is_cuda, is_hip, is_npu, @@ -76,6 +78,8 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() _is_fp8_fnuz = is_fp8_fnuz() @@ -88,7 +92,7 @@ if _is_hip: from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages from aiter.ops.shuffle import shuffle_weight -if not _is_cuda and not _is_npu: +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 0b9a15560..40a381f3b 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -6,12 +6,14 @@ from typing import List, Mapping, Tuple, Union import torch from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant -from sglang.srt.utils import is_cuda, is_npu +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() -if not _is_cuda and not _is_npu: +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7db99d375..bd145a4b0 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -8,11 +8,13 @@ import torch import torch.nn as nn from sglang.srt.custom_op import CustomOp -from sglang.srt.utils import is_cuda, is_hip, is_npu +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu _is_cuda = is_cuda() _is_hip = is_hip() _is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import apply_rope_with_cos_sin_cache_inplace @@ -85,7 +87,9 @@ class RotaryEmbedding(CustomOp): if not _is_cuda: cache = cache.to(dtype) - if not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]: + if ( + not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512] + ) and not (_is_cpu and _is_cpu_amx_available): from vllm._custom_ops import rotary_embedding self.vllm_rotary_embedding = rotary_embedding @@ -148,6 +152,26 @@ class RotaryEmbedding(CustomOp): key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_cpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + positions = torch.add(positions, offsets) if offsets is not None else positions + if _is_cpu_amx_available: + return torch.ops.sgl_kernel.rotary_embedding_cpu( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + else: + return self.forward_native(positions, query, key, offsets) + def forward_cuda( self, positions: torch.Tensor, @@ -697,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): key = key_rot return query.to(dtype), key.to(dtype) + def forward_cpu( + self, + positions: torch.Tensor, + query: 
torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + positions = torch.add(positions, offsets) if offsets is not None else positions + if _is_cpu_amx_available: + return torch.ops.sgl_kernel.rotary_embedding_cpu( + positions, query, key, self.head_size, self.cos_sin_cache, False + ) + else: + return self.forward_native(positions, query, key, offsets) + class Llama3RotaryEmbedding(RotaryEmbedding): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ec546280b..6ffb0aed1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -111,6 +111,7 @@ from sglang.srt.utils import ( ) _is_hip = is_hip() +_is_cpu_amx_available = cpu_has_amx_support() # Use a small KV cache pool size for tests in CI SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) @@ -302,7 +303,7 @@ class ModelRunner: if ( server_args.attention_backend == "intel_amx" and server_args.device == "cpu" - and not cpu_has_amx_support() + and not _is_cpu_amx_available ): logger.info( "The current platform does not support Intel AMX, will fallback to torch_native backend." diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 017c103ba..b21b441cf 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import ( block_dequant as int8_block_dequant, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -95,8 +95,10 @@ from sglang.srt.utils import ( LazyValue, add_prefix, bind_or_assign, + cpu_has_amx_support, get_bool_env_var, get_int_env_var, + is_cpu, is_cuda, is_hip, is_non_idle_and_non_empty, @@ -107,9 +109,13 @@ _is_hip = is_hip() _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 +elif _is_cpu and _is_cpu_amx_available: + pass else: from vllm._custom_ops import awq_dequantize @@ -665,13 +671,14 @@ class DeepseekV2AttentionMLA(nn.Module): if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, + device=global_server_args_dict["device"], ) if rope_scaling: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 11ff3b555..048f1f493 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -160,7 +160,7 @@ def is_npu() -> bool: return hasattr(torch, "npu") and torch.npu.is_available() -def is_cpu() -> bool: +def is_host_cpu_x86() -> bool: machine = platform.machine().lower() return ( machine in ("x86_64", "amd64", "i386", "i686") @@ -169,6 +169,10 @@ def is_cpu() -> bool: ) +def is_cpu() -> bool: + return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86() + + def is_flashinfer_available(): """ Check whether flashinfer is available. 
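Review note: the redefined is_cpu() above makes the CPU engine strictly opt-in: SGLANG_USE_CPU_ENGINE=1 (set by default in Dockerfile.xeon) and an x86 host are both required. A minimal sketch, assuming only the helpers in this file, of how that gate combines with AMX detection; the real selection lives in CustomOp.dispatch_forward (custom_op.py above), which also handles the CUDA and HIP branches:

    from sglang.srt.utils import cpu_has_amx_support, is_cpu

    def pick_forward(op):
        # forward_cpu is used only when the CPU engine is explicitly
        # enabled AND the host supports Intel AMX; otherwise the
        # pure-PyTorch forward_native fallback is kept.
        if is_cpu() and cpu_has_amx_support():
            return op.forward_cpu
        return op.forward_native
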
@@ -1452,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str: "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'." ) + if is_cpu(): + if cpu_has_amx_support(): + logger.info("Intel AMX is detected, using CPU with Intel AMX support.") + else: + logger.warning( + "CPU device enabled, using torch native backend, low performance expected." + ) + return "cpu" + raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.") diff --git a/test/srt/cpu/test_activation.py b/test/srt/cpu/test_activation.py index 7602445dd..a5bfce912 100644 --- a/test/srt/cpu/test_activation.py +++ b/test/srt/cpu/test_activation.py @@ -21,7 +21,7 @@ class TestActivation(CustomTestCase): ref_out = SiluAndMul(x) atol = rtol = precision[ref_out.dtype] - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_activation(self): for params in itertools.product(self.M, self.N, self.dtype): diff --git a/test/srt/cpu/test_gemm.py b/test/srt/cpu/test_gemm.py index bb4094f0d..c7ee838d9 100644 --- a/test/srt/cpu/test_gemm.py +++ b/test/srt/cpu/test_gemm.py @@ -60,8 +60,8 @@ class TestGemm(CustomTestCase): ) atol = rtol = precision[ref.dtype] - self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol) def test_bf16_gemm(self): for params in itertools.product( @@ -100,13 +100,13 @@ class TestGemm(CustomTestCase): out = torch.ops.sgl_kernel.int8_scaled_mm_cpu( Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False ) - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) # test the fused version fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant( A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False ) - self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol) def test_int8_gemm(self): for params in itertools.product( @@ -165,7 +165,7 @@ class TestGemm(CustomTestCase): prepack, ) atol = rtol = precision[ref.dtype] - self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol) def test_fp8_gemm(self): for params in itertools.product( diff --git a/test/srt/cpu/test_moe.py b/test/srt/cpu/test_moe.py index 62542e366..c5852408c 100644 --- a/test/srt/cpu/test_moe.py +++ b/test/srt/cpu/test_moe.py @@ -91,9 +91,7 @@ class TestFusedExperts(CustomTestCase): fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack) atol = rtol = precision[torch_output.dtype] - self.assertTrue( - torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol) - ) + torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol) def test_bf16_moe(self): for params in itertools.product( @@ -171,7 +169,7 @@ class TestFusedExperts(CustomTestCase): # Increase the tolerance for large input shapes if M > 35: atol = rtol = 0.02 - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_int8_moe(self): for params in itertools.product( @@ -235,7 +233,7 @@ class TestFusedExperts(CustomTestCase): ) atol = rtol = precision[dtype] - self.assertTrue(torch.allclose(ref_out.bfloat16(), 
out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol) def test_fp8_moe(self): for params in itertools.product( diff --git a/test/srt/cpu/test_norm.py b/test/srt/cpu/test_norm.py index fa4530afd..6f1065d61 100644 --- a/test/srt/cpu/test_norm.py +++ b/test/srt/cpu/test_norm.py @@ -47,7 +47,7 @@ class TestNorm(CustomTestCase): ref_out = self._forward_native(x, weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) ref_x = x.clone() residual = torch.randn([m, hidden_size], dtype=dtype) @@ -61,8 +61,8 @@ class TestNorm(CustomTestCase): ref_x, weight, variance_epsilon, ref_residual ) - self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol)) + torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) + torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) def _l2norm_test(self, m, n, dtype): @@ -75,7 +75,7 @@ class TestNorm(CustomTestCase): ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_norm(self): for params in itertools.product(self.M, self.N, self.dtype): diff --git a/test/srt/cpu/test_qkv_proj_with_rope.py b/test/srt/cpu/test_qkv_proj_with_rope.py index 9d4b80f6a..4496bb8ed 100644 --- a/test/srt/cpu/test_qkv_proj_with_rope.py +++ b/test/srt/cpu/test_qkv_proj_with_rope.py @@ -211,12 +211,12 @@ class TestQKVProjWithROPE(CustomTestCase): qk_rope_head_dim, ) atol = rtol = precision[q_ref.dtype] - self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(fused_q_out, q_out)) - self.assertTrue(torch.allclose(fused_k_out, k_out)) - self.assertTrue(torch.allclose(fused_v_out, v_out)) + torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) + torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) + torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_q_out, q_out) + torch.testing.assert_close(fused_k_out, k_out) + torch.testing.assert_close(fused_v_out, v_out) def test_int8_qkv_proj_with_rope(self): dtype = torch.bfloat16 @@ -302,12 +302,12 @@ class TestQKVProjWithROPE(CustomTestCase): qk_rope_head_dim, ) atol = rtol = precision[q_ref.dtype] - self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(fused_q_out, q_out)) - self.assertTrue(torch.allclose(fused_k_out, k_out)) - self.assertTrue(torch.allclose(fused_v_out, v_out)) + torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol) + torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol) + torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_q_out, q_out) + torch.testing.assert_close(fused_k_out, k_out) + torch.testing.assert_close(fused_v_out, v_out) def test_fp8_qkv_proj_with_rope(self): dtype = torch.bfloat16 diff --git a/test/srt/cpu/test_rope.py 
b/test/srt/cpu/test_rope.py index b9c5da42b..be4e63c44 100644 --- a/test/srt/cpu/test_rope.py +++ b/test/srt/cpu/test_rope.py @@ -75,8 +75,8 @@ class TestROPE(CustomTestCase): ) atol = rtol = precision[q_pe.dtype] - self.assertTrue(torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol)) - self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol)) + torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol) + torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol) torch.testing.assert_close(k_pe, k_pe_clone) def test_origin_rope(self): diff --git a/test/srt/cpu/test_shared_expert.py b/test/srt/cpu/test_shared_expert.py index ea048495c..bf7840b53 100644 --- a/test/srt/cpu/test_shared_expert.py +++ b/test/srt/cpu/test_shared_expert.py @@ -71,7 +71,7 @@ class TestSharedExpert(CustomTestCase): ) atol = rtol = precision[ref.dtype] - self.assertTrue(torch.allclose(ref, res, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref, res, atol=atol, rtol=rtol) def test_bf16_shared_expert(self): for params in itertools.product( @@ -129,7 +129,7 @@ class TestSharedExpert(CustomTestCase): ) atol = rtol = precision[ref2.dtype] - self.assertTrue(torch.allclose(ref2, res2, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref2, res2, atol=atol, rtol=rtol) def test_int8_shared_expert(self): for params in itertools.product( @@ -199,7 +199,7 @@ class TestSharedExpert(CustomTestCase): ) atol = rtol = precision[ref_out.dtype] - self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_fp8_shared_expert(self): for params in itertools.product( diff --git a/test/srt/cpu/test_topk.py b/test/srt/cpu/test_topk.py index 420f6cbb7..3e08794d7 100644 --- a/test/srt/cpu/test_topk.py +++ b/test/srt/cpu/test_topk.py @@ -8,8 +8,8 @@ from utils import precision from sglang.srt.layers.moe.topk import ( biased_grouped_topk_impl as native_biased_grouped_topk, ) -from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk -from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk +from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk +from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk from sglang.srt.models.llama4 import Llama4MoE from sglang.test.test_utils import CustomTestCase diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a2e39b9aa..06b09f9fa 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -175,9 +175,11 @@ suites = { TestFile("cpu/test_decode.py"), TestFile("cpu/test_extend.py"), TestFile("cpu/test_gemm.py"), + TestFile("cpu/test_mla.py"), TestFile("cpu/test_moe.py"), TestFile("cpu/test_norm.py"), TestFile("cpu/test_qkv_proj_with_rope.py"), + TestFile("cpu/test_rope.py"), TestFile("cpu/test_shared_expert.py"), ], "nightly": [
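Review notes (reference sketches, not part of the patch):

The new SiluAndMul.forward_cpu routes to torch.ops.sgl_kernel.silu_and_mul_cpu when AMX is available. A minimal pure-PyTorch sketch of the computation that kernel is expected to match, mirroring forward_native:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # The last dim packs [gate | up]; apply SiLU to the gate half and
        # multiply elementwise by the up half, halving the last dim.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

test/srt/cpu/test_activation.py checks the CPU kernel against this same native formulation.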
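RMSNorm.forward_cpu calls fused_add_rmsnorm_cpu, which updates x and residual in place when a residual is passed. An out-of-place sketch of the intended math, assuming the kernel follows the same contract as forward_native:

    import torch

    def fused_add_rmsnorm_reference(x, residual, weight, eps):
        # residual <- x + residual, then x <- RMSNorm(residual) * weight,
        # with the variance accumulated in float32.
        residual = x + residual
        var = residual.float().pow(2).mean(dim=-1, keepdim=True)
        x = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
        return x, residual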
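In topk.py, fused_topk_native is renamed to fused_topk_torch_native, and on AMX-capable CPUs the module-level names fused_topk_native, grouped_topk, and biased_grouped_topk are rebound to the sgl_kernel CPU ops at import time, so call sites stay unchanged (this is also why test/srt/cpu/test_topk.py now imports the _torch_native and _gpu names explicitly). A sketch of the torch-native semantics the CPU top-k kernel is expected to reproduce:

    import torch

    def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
        # Softmax over experts, keep the top-k weights per token, and
        # optionally renormalize the kept weights to sum to 1.
        scores = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids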
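Both new RotaryEmbedding forward_cpu overrides fold offsets into positions before calling rotary_embedding_cpu, and the DeepseekScalingRotaryEmbedding variant passes is_neox_style=False (the GPT-J interleaved layout). A per-head sketch of the rotation, assuming cos_sin_cache stores the cos and sin halves concatenated along the last dim:

    import torch

    def rope_reference(positions, x, cos_sin_cache, is_neox_style: bool):
        # x: [num_tokens, rotary_dim]; cos_sin_cache: [max_pos, rotary_dim]
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
        if is_neox_style:
            # NeoX style rotates two contiguous halves of the rotary dim.
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        # GPT-J style rotates interleaved even/odd lanes instead.
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)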
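The test updates replace self.assertTrue(torch.allclose(...)) with torch.testing.assert_close(...). Both enforce |actual - expected| <= atol + rtol * |expected|, but assert_close also checks dtype and device by default and, on failure, reports which elements mismatch and by how much instead of a bare "False is not true". A tiny illustration with hypothetical values:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = torch.tensor([1.0, 2.5])
    torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)
    # AssertionError: Tensor-likes are not close!
    # Mismatched elements: 1 / 2 (50.0%) ...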