diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 3ea6b4b2f..ab7350555 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda _is_cuda = is_cuda() if _is_cuda: from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_fp8, + sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8, ) logger = logging.getLogger(__name__) @@ -654,10 +654,7 @@ def grouped_gemm_triton( if block_shape is not None: assert len(block_shape) == 2 block_n, block_k = block_shape[0], block_shape[1] - if _is_cuda: - a, scale_a = sglang_per_token_group_quant_fp8(a, block_k) - else: - a, scale_a = per_token_group_quant_fp8(a, block_k) + a, scale_a = per_token_group_quant_fp8(a, block_k) assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 67496b14b..e8fd243e4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -10,16 +10,14 @@ import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( all_close_1d, - is_cuda, - is_fp8_fnuz, per_tensor_dequantize, replace_parameter, ) -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import is_cuda, set_weight_attrs _is_cuda = is_cuda() diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ed13195c5..210a24f69 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import ( from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale +from sglang.srt.layers.quantization.utils import requantize_with_max_scale __all__ = ["CompressedTensorsW8A8Fp8"] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 2a5de2f21..b5fdccb88 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + is_fp8_fnuz, per_token_group_quant_fp8, scaled_fp8_quant, ) @@ -71,6 +73,11 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() +_is_fp8_fnuz = is_fp8_fnuz() + +use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") +use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") + if _is_hip: from aiter import ActivationType, QuantType from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages @@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase): # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_hip: + if _is_fp8_fnuz: # activation_scheme: dynamic weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, weight_scale=layer.weight_scale_inv, input_scale=None, ) - layer.weight = torch.nn.Parameter(weight, requires_grad=False) - layer.weight_scale_inv = torch.nn.Parameter( - weight_scale, requires_grad=False - ) + layer.input_scale = None else: - layer.weight = torch.nn.Parameter( - layer.weight.data, requires_grad=False - ) - layer.weight_scale_inv = torch.nn.Parameter( - layer.weight_scale_inv.data, requires_grad=False - ) + weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale_inv = torch.nn.Parameter( + weight_scale, requires_grad=False + ) return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) @@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase): weight = layer.weight weight_scale = layer.weight_scale # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_hip: + if _is_fp8_fnuz: weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, @@ -482,11 +485,7 @@ class Fp8MoEMethod: from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = ( - torch.uint32 - if get_bool_env_var("SGLANG_INT4_WEIGHT") - else torch.float8_e4m3fn - ) + params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn tp_size = get_tensor_model_parallel_world_size() if self.block_quant: block_n, block_k = ( @@ -511,7 +510,7 @@ class Fp8MoEMethod: ) # WEIGHTS - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and use_hip_int4: # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -583,9 +582,7 @@ class Fp8MoEMethod: layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - if ( - _is_hip - ): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel + if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), @@ -612,7 +609,7 @@ class Fp8MoEMethod: set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and use_hip_int4: extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) @@ -644,14 +641,14 @@ class Fp8MoEMethod: layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"): + if _is_hip and use_hip_int4: self.process_weights_hip_int4(layer) return # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_hip: + if _is_fp8_fnuz: # activation_scheme: dynamic w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.w13_weight, @@ -675,20 +672,19 @@ class Fp8MoEMethod: ) layer.w2_input_scale = None - if get_bool_env_var("SGLANG_AITER_MOE"): - # Pre-shuffle weights - layer.w13_weight.data = shuffle_weight( - layer.w13_weight.contiguous(), (16, 16) - ) - layer.w2_weight.data = shuffle_weight( - layer.w2_weight.contiguous(), (16, 16) - ) + if _is_hip and use_aiter_moe: + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) return # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - # If ROCm, use float8_e4m3fnuz instead (MI300x HW) - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -742,7 +738,7 @@ class Fp8MoEMethod: ) # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_hip: + if _is_fp8_fnuz: # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -798,7 +794,7 @@ class Fp8MoEMethod: return def process_weights_hip_int4(self, layer: Module): - # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added + # TODO: and use_aiter_moe: add after triton kernel added # INT4-FP8 (INT4 MoE Weight, FP8 Compute) # Weight Permutation layer.w13_weight = torch.nn.Parameter( @@ -845,7 +841,7 @@ class Fp8MoEMethod: padding_size, # Avoid circular import ) - if get_bool_env_var("SGLANG_AITER_MOE"): + if use_aiter_moe: layer.w13_weight = torch.nn.Parameter( shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, @@ -856,7 +852,7 @@ class Fp8MoEMethod: requires_grad=False, ) torch.cuda.empty_cache() - # ROCm (SGLANG_AITER_MOE): using column-wise scaling + # ROCm (use_aiter_moe): using column-wise scaling layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1) layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1) elif get_bool_env_var("SGLANG_MOE_PADDING"): @@ -908,59 +904,16 @@ class Fp8MoEMethod: ) if _is_hip: - if get_bool_env_var("SGLANG_INT4_WEIGHT"): - # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE") - assert not no_combine, f"{no_combine=} is not supported." - return ck_moe_2stages( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - QuantType.per_Token, - layer.w13_weight_scale1, - layer.w2_weight_scale1, - activation=( - ActivationType.Silu - if activation == "silu" - else ActivationType.Gelu - ), - ) - - if get_bool_env_var("SGLANG_AITER_MOE"): - assert not no_combine, f"{no_combine=} is not supported." - if self.block_quant: - # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being. - assert ( - activation == "silu" - ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE" - return asm_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - layer.w13_weight_scale_inv, - layer.w2_weight_scale_inv, - block_shape=tuple(self.quant_config.weight_block_size), - expert_mask=None, - ) - else: - return ck_moe_2stages( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - QuantType.per_Token, - layer.w13_weight_scale1, - layer.w2_weight_scale1, - activation=( - ActivationType.Silu - if activation == "silu" - else ActivationType.Gelu - ), - ) + ret = self.maybe_apply_hip_fused_experts( + layer, + x, + topk_weights, + topk_ids, + activation, + no_combine, + ) + if ret is not None: + return ret # Expert fusion with FP8 quantization return fused_experts( @@ -987,6 +940,68 @@ class Fp8MoEMethod: no_combine=no_combine, ) + def maybe_apply_hip_fused_experts( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + no_combine: bool = False, + ) -> Optional[torch.Tensor]: + if use_hip_int4: + # TODO: add triton kernel and add check use_aiter_moe + assert not no_combine, f"{no_combine=} is not supported." + return ck_moe_2stages( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + QuantType.per_Token, + layer.w13_weight_scale1, + layer.w2_weight_scale1, + activation=( + ActivationType.Silu if activation == "silu" else ActivationType.Gelu + ), + ) + + if use_aiter_moe: + assert not no_combine, f"{no_combine=} is not supported." + if self.block_quant: + # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being. + assert ( + activation == "silu" + ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe" + return asm_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + layer.w13_weight_scale_inv, + layer.w2_weight_scale_inv, + block_shape=tuple(self.quant_config.weight_block_size), + expert_mask=None, + ) + else: + return ck_moe_2stages( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + QuantType.per_Token, + layer.w13_weight_scale1, + layer.w2_weight_scale1, + activation=( + ActivationType.Silu + if activation == "silu" + else ActivationType.Gelu + ), + ) + return None + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 45157527e..e52f69142 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -16,6 +16,7 @@ import functools import json import logging import os +from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple import torch @@ -34,12 +35,6 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() -_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn -if _is_hip: - fp8_max = 224.0 -else: - fp8_max = torch.finfo(_fp8_type).max -fp8_min = -fp8_max if _is_cuda: from sgl_kernel import ( @@ -54,6 +49,24 @@ if _is_cuda: logger = logging.getLogger(__name__) + +@lru_cache() +def is_fp8_fnuz() -> bool: + if _is_hip: + # only device 0 is checked, this assumes MI300 platforms are homogeneous + return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName + return False + + +if is_fp8_fnuz(): + fp8_dtype = torch.float8_e4m3fnuz + fp8_max = 224.0 +else: + fp8_dtype = torch.float8_e4m3fn + fp8_max = torch.finfo(fp8_dtype).max +fp8_min = -fp8_max + + if supports_custom_op(): def deep_gemm_fp8_fp8_bf16_nt( @@ -198,7 +211,7 @@ def per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) M = x.numel() // group_size N = group_size if column_major_scales: @@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) if column_major_scales: if scale_tma_aligned: # aligned to 4 * sizeof(float) @@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8( def sglang_per_token_quant_fp8( x: torch.Tensor, - dtype: torch.dtype = _fp8_type, + dtype: torch.dtype = fp8_dtype, ): assert x.is_contiguous(), "`x` is not contiguous" @@ -384,7 +397,7 @@ def static_quant_fp8( assert x.is_contiguous(), "`x` is not contiguous" assert x_s.numel() == 1, "only supports per-tensor scale" - x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) M = x.numel() // x.shape[-1] N = x.shape[-1] if repeat_scale: @@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs( return None +def select_w8a8_block_fp8_matmul_kernel(M, N, META): + return _w8a8_block_fp8_matmul + + +if _is_hip: + + def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META): + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. + # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. + num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv( + N, META["BLOCK_SIZE_N"] + ) + num_workgroups <= get_device_core_count() + + def select_w8a8_block_fp8_matmul_kernel(M, N, META): + if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META): + return _w8a8_block_fp8_matmul_unrolledx4 + else: + return _w8a8_block_fp8_matmul + + def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul( C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) - configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Default config - # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3, - } - - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - - # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. - # Empirical testing shows the sweet spot lies when it's less than the # of - # compute units available on the device. - num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( - N, config["BLOCK_SIZE_N"] - ) - # deepgemm only support bf16 if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM: if supports_custom_op(): @@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul( else: deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C) else: - kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 - if (_is_hip == True and num_workgroups <= get_device_core_count()) - else _w8a8_block_fp8_matmul - ) + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config) kernel[grid]( A, @@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8( and x_s_out.device == x.device ) - x_q = x.new_empty(x.size(), dtype=_fp8_type) + x_q = x.new_empty(x.size(), dtype=fp8_dtype) num_head, num_seq, head_size = x.shape BLOCK_SIZE = triton.next_power_of_2(head_size) @@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8( tl.store(y_s_ptr + gid * y_s_stride_g, y_s) -def per_tensor_quant_mla_deep_gemm_masked_fp8( +def per_token_group_quant_mla_deep_gemm_masked_fp8( x: torch.Tensor, group_size: int = 128, eps: float = 1e-12, - dtype: torch.dtype = torch.float8_e4m3fn, + dtype: torch.dtype = fp8_dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: """ This function quantizes input values to float8 values with per-token-group-quantization @@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8( """ assert x.dim() == 3, "`x` is not a 3d-tensor" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - if _is_hip: - dtype = torch.float8_e4m3fnuz - fp8_max = 224.0 - b, m, k = x.shape aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel num_tiles_k = k // group_size @@ -1043,10 +1062,9 @@ def scaled_fp8_quant( """ assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" shape = input.shape - out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + output = torch.empty(shape, device=input.device, dtype=fp8_dtype) if scale is None: # Dynamic scaling diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c137d3ad1..aeab9d48d 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -14,6 +14,9 @@ except ImportError: from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + fp8_max, + is_fp8_fnuz, per_token_group_quant_fp8, scaled_fp8_quant, sglang_per_token_quant_fp8, @@ -30,8 +33,11 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() +_is_fp8_fnuz = is_fp8_fnuz() -if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): +use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") + +if _is_hip and use_aiter_moe: from aiter import gemm_a8w8_blockscale if _is_cuda: @@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = None -_TORCH_VERSION = torch.__version__.split("+")[0] -try: - _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3])) -except ValueError: - _TORCH_VERSION_TUPLE = (0, 0, 0) -# The condition to determine if it is on a platform that supports -# torch._scaled_mm rowwise feature. -# The condition is determined once as the operations -# are time consuming. -USE_ROWWISE_TORCH_SCALED_MM = ( - _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0) -) +def use_rowwise_torch_scaled_mm(): + _TORCH_VERSION = torch.__version__.split("+")[0] + try: + _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3])) + except ValueError: + _TORCH_VERSION_TUPLE = (0, 0, 0) + if _is_hip: + # The condition to determine if it is on a platform that supports + # torch._scaled_mm rowwise feature. + # The condition is determined once as the operations + # are time consuming. + return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0) + return False + + +USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm() def cutlass_fp8_supported(): @@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear( output = fp8_blockwise_scaled_mm( q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype ) - elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): + elif _is_hip and use_aiter_moe: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=False ) @@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear( def input_to_float8( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn + x: torch.Tensor, dtype: torch.dtype = fp8_dtype ) -> Tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values with tensor-wise quantization.""" - finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12) - fp8_max = finfo.max - if _is_hip: - dtype = torch.float8_e4m3fnuz - fp8_max = 224.0 - scale = fp8_max / amax - x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max) + + if _is_fp8_fnuz: + dtype = fp8_dtype + fp_max = fp8_max + else: + finfo = torch.finfo(dtype) + fp_max = finfo.max + + scale = fp_max / amax + x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max) return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() diff --git a/python/sglang/srt/layers/quantization/kv_cache.py b/python/sglang/srt/layers/quantization/kv_cache.py index da6d91a9b..7da2dac17 100644 --- a/python/sglang/srt/layers/quantization/kv_cache.py +++ b/python/sglang/srt/layers/quantization/kv_cache.py @@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import is_hip - -_is_hip = is_hip() logger = logging.getLogger(__name__) @@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase): torch.tensor(-1.0, dtype=torch.float32), requires_grad=False ) - @classmethod - def is_fp8_fnuz(cls) -> bool: - # only device 0 is checked, this assumes MI300 platforms are homogeneous - return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName - def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") @@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() - if _is_hip and self.is_fp8_fnuz(): + if is_fp8_fnuz(): k_scale *= 2 v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: @@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() - if _is_hip and self.is_fp8_fnuz(): + if is_fp8_fnuz(): k_scale *= 2 v_scale *= 2 diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index fdea997e3..d2bbce494 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -14,11 +14,6 @@ if not _is_cuda: from vllm._custom_ops import scaled_fp8_quant -def is_fp8_fnuz() -> bool: - # only device 0 is checked, this assumes MI300 platforms are homogeneous - return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName - - def is_layer_skipped( prefix: str, ignored_layers: List[str], diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 48cf5db34..26a3259e8 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + is_fp8_fnuz, + per_token_group_quant_fp8, +) from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, cutlass_fp8_supported, input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.utils import is_hip, set_weight_attrs +from sglang.srt.utils import set_weight_attrs -_is_hip = is_hip() +_is_fp8_fnuz = is_fp8_fnuz() class W8A8Fp8Config(QuantizationConfig): @@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): if self.quantization_config.is_checkpoint_fp8_serialized: weight_scale = layer.weight_scale.detach() # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly. - if _is_hip: + if _is_fp8_fnuz: weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale ) @@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase): layer.weight, layer.weight.shape[-1] ) weight_scale = weight_scale.t().contiguous() - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale - ) else: # if cutlass not supported, we fall back to use torch._scaled_mm # which requires per tensor quantization on weight - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype) # Update the layer with the new values. @@ -227,7 +226,6 @@ class W8A8FP8MoEMethod: ): from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index d01bc3ae9..8f55d8408 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import EPMoE -from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8_utils import ( - block_quant_to_tensor_quant, - normalize_e4m3fn_to_e4m3fnuz, -) -from sglang.srt.layers.quantization.int8_utils import ( - block_dequant as int8_block_dequant, -) from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM -from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip - -_is_hip = is_hip() -_is_cuda = is_cuda() - -if _is_cuda: - from sgl_kernel import awq_dequantize -else: - from vllm._custom_ops import awq_dequantize - +from sglang.srt.utils import BumpAllocator, add_prefix logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 35c19e14b..2d5906f80 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.fp8_kernel import ( - per_tensor_quant_mla_deep_gemm_masked_fp8, per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, ) from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, @@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module): if self.use_deep_gemm_bmm: q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( - per_tensor_quant_mla_deep_gemm_masked_fp8( - q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn - ) + per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1)) ) q_nope_out = q_nope.new_empty( (self.num_local_heads, aligned_m, self.kv_lora_rank) @@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module): if self.use_deep_gemm_bmm: attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( - per_tensor_quant_mla_deep_gemm_masked_fp8( - attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn + per_token_group_quant_mla_deep_gemm_masked_fp8( + attn_output.transpose(0, 1) ) ) attn_bmm_output = attn_output.new_empty( diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 117acf3a1..b331f5a87 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,9 +7,9 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( - per_tensor_quant_mla_deep_gemm_masked_fp8, per_tensor_quant_mla_fp8, per_token_group_quant_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, static_quant_fp8, w8a8_block_fp8_matmul, ) @@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase): with torch.inference_mode(): ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12) - out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8( + out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8( x, group_size ) out = out[:, :num_tokens, :]