[refactor] slightly tidy fp8 module (#5993)
@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
)
logger = logging.getLogger(__name__)

@@ -654,10 +654,7 @@ def grouped_gemm_triton(
if block_shape is not None:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda:
a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
else:
a, scale_a = per_token_group_quant_fp8(a, block_k)
a, scale_a = per_token_group_quant_fp8(a, block_k)

assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
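Note: the hunks above bind the CUDA kernel to the generic name at import time so the call site no longer branches. A minimal sketch of that aliasing pattern, with illustrative names rather than the sglang modules:

IS_CUDA = False  # stands in for is_cuda()

def _triton_per_token_group_quant_fp8(x, group_size):
    """Portable fallback path (placeholder)."""
    raise NotImplementedError

def _sgl_kernel_per_token_group_quant_fp8(x, group_size):
    """CUDA fast path backed by sgl-kernel (placeholder)."""
    raise NotImplementedError

# Call sites import only this one name; the choice is made once at import.
per_token_group_quant_fp8 = (
    _sgl_kernel_per_token_group_quant_fp8 if IS_CUDA else _triton_per_token_group_quant_fp8
)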
@@ -10,16 +10,14 @@ import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
is_cuda,
is_fp8_fnuz,
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda, set_weight_attrs

_is_cuda = is_cuda()
@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
from sglang.srt.layers.quantization.utils import requantize_with_max_scale

__all__ = ["CompressedTensorsW8A8Fp8"]
@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
scaled_fp8_quant,
)

@@ -71,6 +73,11 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()

_is_fp8_fnuz = is_fp8_fnuz()

use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")

if _is_hip:
from aiter import ActivationType, QuantType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
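Note: the hunk above hoists the repeated env-var checks into module-level flags that are read once at import and reused everywhere. A rough sketch of that pattern, with a local stand-in for sglang's get_bool_env_var helper:

import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Treat "1", "true", "yes" (any case) as enabled.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# Evaluated once at import time, then reused by every call site.
use_hip_int4 = get_bool_env_var_sketch("SGLANG_INT4_WEIGHT")
use_aiter_moe = get_bool_env_var_sketch("SGLANG_AITER_MOE")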
@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase):
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=None,
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)

layer.input_scale = None
else:
layer.weight = torch.nn.Parameter(
layer.weight.data, requires_grad=False
)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
return

layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
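Note: the recurring change in these hunks is that e4m3fnuz normalization is now gated on the fp8 flavour the device actually uses (is_fp8_fnuz()) rather than on "is this ROCm at all". A rough sketch of what that normalization amounts to; the helper name and conversion details are illustrative, not the exact normalize_e4m3fn_to_e4m3fnuz implementation:

import torch

def maybe_normalize_to_fnuz(weight: torch.Tensor, weight_scale: torch.Tensor, fnuz: bool):
    if not fnuz:
        # Non-fnuz platforms keep e4m3fn weights untouched.
        return weight, weight_scale
    # The same bit pattern decodes to half the value in e4m3fnuz, so the
    # payload is reinterpreted and the scale doubled to compensate.
    converted = weight.view(torch.uint8).view(torch.float8_e4m3fnuz)
    return converted, weight_scale * 2.0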
@@ -482,11 +485,7 @@ class Fp8MoEMethod:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = (
torch.uint32
if get_bool_env_var("SGLANG_INT4_WEIGHT")
else torch.float8_e4m3fn
)
params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (

@@ -511,7 +510,7 @@ class Fp8MoEMethod:
)

# WEIGHTS
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(

@@ -583,9 +582,7 @@ class Fp8MoEMethod:
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

if (
_is_hip
): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +609,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)

@@ -644,14 +641,14 @@ class Fp8MoEMethod:
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
if _is_hip and use_hip_int4:
self.process_weights_hip_int4(layer)
return

# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
@@ -675,20 +672,19 @@ class Fp8MoEMethod:
)
layer.w2_input_scale = None

if get_bool_env_var("SGLANG_AITER_MOE"):
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
if _is_hip and use_aiter_moe:
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
return

# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -742,7 +738,7 @@ class Fp8MoEMethod:
)

# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_hip:
if _is_fp8_fnuz:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(

@@ -798,7 +794,7 @@ class Fp8MoEMethod:
return

def process_weights_hip_int4(self, layer: Module):
# TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
# TODO: and use_aiter_moe: add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(

@@ -845,7 +841,7 @@ class Fp8MoEMethod:
padding_size, # Avoid circular import
)

if get_bool_env_var("SGLANG_AITER_MOE"):
if use_aiter_moe:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,

@@ -856,7 +852,7 @@ class Fp8MoEMethod:
requires_grad=False,
)
torch.cuda.empty_cache()
# ROCm (SGLANG_AITER_MOE): using column-wise scaling
# ROCm (use_aiter_moe): using column-wise scaling
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -908,59 +904,16 @@ class Fp8MoEMethod:
)

if _is_hip:
if get_bool_env_var("SGLANG_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)

if get_bool_env_var("SGLANG_AITER_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"SGLANG_AITER_MOE: FP8 block_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
topk_weights,
topk_ids,
activation,
no_combine,
)
if ret is not None:
return ret

# Expert fusion with FP8 quantization
return fused_experts(
@@ -987,6 +940,68 @@ class Fp8MoEMethod:
no_combine=no_combine,
)

def maybe_apply_hip_fused_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
no_combine: bool = False,
) -> Optional[torch.Tensor]:
if use_hip_int4:
# TODO: add triton kernel and add check use_aiter_moe
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)

if use_aiter_moe:
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"use_aiter_moe: FP8 block_quant {activation=} will be supported later, unset use_aiter_moe"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
QuantType.per_Token,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
return None


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
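Note: the two hunks above move the HIP-only branches out of apply() into an "optional fast path" helper that returns a result when a special backend applies and None to fall through. A self-contained sketch of that shape, with illustrative names rather than the Fp8MoEMethod API:

from typing import List, Optional

class MoEMethodSketch:
    def __init__(self, use_special_backend: bool = False):
        self.use_special_backend = use_special_backend

    def apply(self, x: List[float]) -> List[float]:
        ret = self.maybe_apply_special_backend(x)
        if ret is not None:
            return ret
        # Default fused-experts path.
        return [v * 2.0 for v in x]

    def maybe_apply_special_backend(self, x: List[float]) -> Optional[List[float]]:
        if self.use_special_backend:
            return [v * 3.0 for v in x]
        return None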
@@ -16,6 +16,7 @@ import functools
import json
import logging
import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import torch

@@ -34,12 +35,6 @@ from sglang.srt.utils import (

_is_hip = is_hip()
_is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max

if _is_cuda:
from sgl_kernel import (

@@ -54,6 +49,24 @@ if _is_cuda:

logger = logging.getLogger(__name__)


@lru_cache()
def is_fp8_fnuz() -> bool:
if _is_hip:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
return False


if is_fp8_fnuz():
fp8_dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
else:
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max


if supports_custom_op():

def deep_gemm_fp8_fp8_bf16_nt(
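Note: fp8_kernel.py now centralizes the platform probe behind a cached is_fp8_fnuz() and derives module-level fp8_dtype/fp8_max constants from it. A sketch of the same idea, with the gfx94 check mirroring the diff and a renamed probe to avoid implying it is the sglang function; the 224.0 ceiling is the e4m3fnuz range used on MI300-class parts:

from functools import lru_cache
import torch

@lru_cache()
def is_fp8_fnuz_sketch() -> bool:
    # Only device 0 is checked; assumes a homogeneous node.
    if torch.version.hip is not None and torch.cuda.is_available():
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
    return False

if is_fp8_fnuz_sketch():
    fp8_dtype = torch.float8_e4m3fnuz
    fp8_max = 224.0
else:
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max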
@@ -198,7 +211,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:

@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)

@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8(

def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = _fp8_type,
dtype: torch.dtype = fp8_dtype,
):
assert x.is_contiguous(), "`x` is not contiguous"

@@ -384,7 +397,7 @@ def static_quant_fp8(
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"

x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs(
return None


def select_w8a8_block_fp8_matmul_kernel(M, N, META):
return _w8a8_block_fp8_matmul


if _is_hip:

def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
N, META["BLOCK_SIZE_N"]
)
return num_workgroups <= get_device_core_count()

def select_w8a8_block_fp8_matmul_kernel(M, N, META):
if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
return _w8a8_block_fp8_matmul_unrolledx4
else:
return _w8a8_block_fp8_matmul


def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
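Note: the hunk above replaces an inline `_is_hip == True` check inside the matmul with a selector function that is defined once and shadowed on ROCm. A sketch of that indirection; the kernel objects and core count are placeholders, not the real Triton kernels or device query:

import math

def _w8a8_matmul_default(*args, **kwargs): ...
def _w8a8_matmul_unrolledx4(*args, **kwargs): ...

IS_HIP_SKETCH = False       # stands in for is_hip()
DEVICE_CORE_COUNT = 304     # stands in for get_device_core_count()

def select_matmul_kernel(M, N, META):
    return _w8a8_matmul_default

if IS_HIP_SKETCH:
    def _use_unrolledx4(M, N, META) -> bool:
        # Small launch grids underutilize the GPU, so prefer the unrolled
        # kernel when there are fewer workgroups than compute units.
        num_workgroups = math.ceil(M / META["BLOCK_SIZE_M"]) * math.ceil(N / META["BLOCK_SIZE_N"])
        return num_workgroups <= DEVICE_CORE_COUNT

    def select_matmul_kernel(M, N, META):
        if _use_unrolledx4(M, N, META):
            return _w8a8_matmul_unrolledx4
        return _w8a8_matmul_default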
@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul(
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)

configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}

def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)

# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)

# deepgemm only supports bf16
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if supports_custom_op():

@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul(
else:
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}

def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)

kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)

kernel[grid](
A,
@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8(
and x_s_out.device == x.device
)

x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_q = x.new_empty(x.size(), dtype=fp8_dtype)

num_head, num_seq, head_size = x.shape
BLOCK_SIZE = triton.next_power_of_2(head_size)

@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)


def per_tensor_quant_mla_deep_gemm_masked_fp8(
def per_token_group_quant_mla_deep_gemm_masked_fp8(
x: torch.Tensor,
group_size: int = 128,
eps: float = 1e-12,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: torch.dtype = fp8_dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with per-token-group-quantization

@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"

finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0

b, m, k = x.shape
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
num_tiles_k = k // group_size
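Note: the rename above (per_tensor_quant_mla_deep_gemm_masked_fp8 -> per_token_group_quant_mla_deep_gemm_masked_fp8) makes the function name match what it computes: one scale per (token, group). A plain-PyTorch reference formulation of per-token-group fp8 quantization, offered as a sketch of the intended math rather than the Triton kernel:

import torch

def per_token_group_quant_fp8_reference(
    x: torch.Tensor,
    group_size: int = 128,
    eps: float = 1e-12,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(dtype).max
    # One scale per (token, group of `group_size` channels).
    x_grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).float().clamp(min=eps)
    scale = amax / fp8_max
    x_q = (x_grouped / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q.reshape(x.shape), scale.squeeze(-1)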
@@ -1043,10 +1062,9 @@ def scaled_fp8_quant(
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)

if scale is None:
# Dynamic scaling
@@ -14,6 +14,9 @@ except ImportError:

from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
fp8_max,
is_fp8_fnuz,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,

@@ -30,8 +33,11 @@ from sglang.srt.utils import (

_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()

if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")

if _is_hip and use_aiter_moe:
from aiter import gemm_a8w8_blockscale

if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None

_TORCH_VERSION = torch.__version__.split("+")[0]
try:
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
_TORCH_VERSION_TUPLE = (0, 0, 0)

# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
_is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)
def use_rowwise_torch_scaled_mm():
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
_TORCH_VERSION_TUPLE = (0, 0, 0)
if _is_hip:
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
return False


USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()


def cutlass_fp8_supported():
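Note: the hunk above folds the version parsing and capability check into a function so the work only matters on ROCm, while the module still exposes a single constant. A sketch of that gate; the (9, 4) capability and torch >= 2.7 thresholds mirror the diff, and the function name is a stand-in:

import torch

def _torch_version_tuple() -> tuple:
    try:
        return tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:3])
    except ValueError:
        return (0, 0, 0)

def use_rowwise_torch_scaled_mm_sketch(is_hip: bool) -> bool:
    if not is_hip:
        return False
    # Only evaluated on ROCm; requires a sufficiently new device and torch.
    return torch.cuda.get_device_capability() >= (9, 4) and _torch_version_tuple() >= (2, 7, 0)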
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
elif _is_hip and use_aiter_moe:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(


def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values with tensor-wise quantization."""
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)

if _is_fp8_fnuz:
dtype = fp8_dtype
fp_max = fp8_max
else:
finfo = torch.finfo(dtype)
fp_max = finfo.max

scale = fp_max / amax
x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
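Note: the hunk above keeps input_to_float8 doing tensor-wise quantization, only sourcing the dtype and range from the shared fp8_kernel constants. A sketch of that quantization path: one scale from the global abs-max, values saturated to the format's range, and the returned scale being the dequantization factor (1 / quant scale):

from typing import Tuple
import torch

def input_to_float8_sketch(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    fp_max = torch.finfo(dtype).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp_max / amax
    x_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
    return x_sat.to(dtype).contiguous(), scale.reciprocal()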
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_hip

_is_hip = is_hip()

logger = logging.getLogger(__name__)

@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)

@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
if is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:

@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
if is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
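Note: the KV-cache hunks above only swap the local classmethod probe for the shared is_fp8_fnuz(); the scale handling is unchanged. A sketch of what that handling does, under the assumption (not stated in the diff) that the doubling compensates for e4m3fnuz covering roughly half the range of e4m3fn, whose scales the checkpoints store:

def adjust_kv_scales(k_scale: float, v_scale: float, fnuz: bool) -> tuple:
    if fnuz:
        # Double checkpoint scales when the fnuz fp8 format is in use.
        return k_scale * 2, v_scale * 2
    return k_scale, v_scale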
@@ -14,11 +14,6 @@ if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant


def is_fp8_fnuz() -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName


def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip, set_weight_attrs
from sglang.srt.utils import set_weight_attrs

_is_hip = is_hip()
_is_fp8_fnuz = is_fp8_fnuz()


class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
if self.quantization_config.is_checkpoint_fp8_serialized:
weight_scale = layer.weight_scale.detach()
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
if _is_hip:
if _is_fp8_fnuz:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)

@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
layer.weight, layer.weight.shape[-1]
)
weight_scale = weight_scale.t().contiguous()
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
else:
# if cutlass not supported, we fall back to use torch._scaled_mm
# which requires per tensor quantization on weight
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)

# Update the layer with the new values.

@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_cuda:
from sgl_kernel import awq_dequantize
else:
from vllm._custom_ops import awq_dequantize

from sglang.srt.utils import BumpAllocator, add_prefix

logger = logging.getLogger(__name__)
@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,

@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):

if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
)
per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)

@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):

if self.use_deep_gemm_bmm:
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
per_token_group_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1)
)
)
attn_bmm_output = attn_output.new_empty(
@@ -7,9 +7,9 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)

@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):

with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
x, group_size
)
out = out[:, :num_tokens, :]