diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 77d849f3f..83f74fb27 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -54,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
     from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 
-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 246606746..9c13c7e9d 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -39,11 +39,20 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -1521,11 +1530,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
-    if (
-        not (use_fp8_w8a8 or use_int8_w8a8)
-        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-    ):
+    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
 
     # Check constraints.
@@ -1723,6 +1728,17 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                     routed_scaling_factor,
                 )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 39e5f9e25..af1f6cbf7 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 23daa5d26..6fa3ccc59 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
 
 
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 79504265c..b488a65c0 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     is_cpu,
@@ -39,6 +40,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import (
@@ -47,6 +49,22 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
+if _is_hip:
+    if _use_aiter:
+        try:
+            from aiter import (  # v0.1.3
+                dynamic_per_tensor_quant,
+                dynamic_per_token_scaled_quant,
+                static_per_tensor_quant,
+            )
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        try:
+            import vllm._C
+        except ImportError:
+            raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
+
 
 logger = logging.getLogger(__name__)
 
@@ -1116,58 +1134,109 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
     return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
 
 
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 (8-bit floating point) format.
+"""
+Quantize input tensor to FP8 (8-bit floating point) format.
 
-    Args:
-        input (torch.Tensor): Input tensor to be quantized
-        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-            If None, scales will be computed dynamically.
-        num_token_padding (Optional[int]): If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-            determines the quantization granularity:
-            - True: compute scale per token
-            - False: compute single scale per tensor
+Args:
+    input (torch.Tensor): Input tensor to be quantized
+    scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+        If None, scales will be computed dynamically.
+    num_token_padding (Optional[int]): If specified, pad the first dimension
+        of the output to at least this value.
+    use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+        determines the quantization granularity:
+        - True: compute scale per token
+        - False: compute single scale per tensor
 
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - quantized_tensor: The FP8 quantized version of input
-            - scale_tensor: The scaling factors used for quantization
+Returns:
+    Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+        - quantized_tensor: The FP8 quantized version of input
+        - scale_tensor: The scaling factors used for quantization
 
-    Raises:
-        AssertionError: If input is not 2D or if static scale's numel != 1
-    """
-    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-    shape = input.shape
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+Raises:
+    AssertionError: If input is not 2D or if static scale's numel != 1
+"""
+if _is_hip:
 
-    if scale is None:
-        # Dynamic scaling
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            sgl_per_token_quant_fp8(input, output, scale)
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                if _use_aiter:
+                    dynamic_per_token_scaled_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                        output, input.contiguous(), scale, None
+                    )
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                if _use_aiter:
+                    dynamic_per_tensor_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
         else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=False
-            )  # False for dynamic
-    else:
-        # Static scaling
-        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
-        sgl_per_tensor_quant_fp8(
-            input, output, scale, is_static=True
-        )  # True for static
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            if _use_aiter:
+                static_per_tensor_quant(output, input, scale)
+            else:
+                torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
-    return output, scale
+        return output, scale
+
+else:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                sgl_per_token_quant_fp8(input, output, scale)
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                sgl_per_tensor_quant_fp8(
+                    input, output, scale, is_static=False
+                )  # False for dynamic
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=True
+            )  # True for static
+
+        return output, scale
 
 
 fp8_autotune = triton.autotune(
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index fa4cbf582..ddafcc6f5 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
 
diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py
index 89e0eb84a..8904247a6 100644
--- a/python/sglang/srt/layers/quantization/utils.py
+++ b/python/sglang/srt/layers/quantization/utils.py
@@ -12,7 +12,7 @@ import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
 
 
diff --git a/python/sglang/test/test_custom_ops.py b/python/sglang/test/test_custom_ops.py
index 873f9960e..c07c95db6 100644
--- a/python/sglang/test/test_custom_ops.py
+++ b/python/sglang/test/test_custom_ops.py
@@ -3,8 +3,13 @@
 import pytest
 import torch
 
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight
 
     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )
 
 
-if is_cuda:
+if _is_cuda or _is_hip:
 
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
            )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight
 
         def dequantize_per_token(tensor, inv_scale, dtype):
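For reference, a minimal usage sketch of the unified scaled_fp8_quant entry point that both branches above define; the tensor shape, dtype, and device string below are illustrative assumptions, not values taken from this change.

import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

# Illustrative 2D activation tensor; scaled_fp8_quant asserts input.ndim == 2.
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")

# Dynamic per-tensor scaling: a single scalar scale for the whole tensor.
q_per_tensor, scale_per_tensor = scaled_fp8_quant(x)

# Dynamic per-token scaling: one scale per row, returned with shape (num_tokens, 1).
q_per_token, scale_per_token = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling: reuse a precomputed scalar scale (scale.numel() must be 1).
q_static, _ = scaled_fp8_quant(x, scale=scale_per_tensor)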