[refactor] slightly tidy fp8 module (#5993)
This commit is contained in:
@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -654,10 +654,7 @@ def grouped_gemm_triton(
|
|||||||
if block_shape is not None:
|
if block_shape is not None:
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
block_n, block_k = block_shape[0], block_shape[1]
|
block_n, block_k = block_shape[0], block_shape[1]
|
||||||
if _is_cuda:
|
a, scale_a = per_token_group_quant_fp8(a, block_k)
|
||||||
a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
|
|
||||||
else:
|
|
||||||
a, scale_a = per_token_group_quant_fp8(a, block_k)
|
|
||||||
|
|
||||||
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
|
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
|
||||||
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
|
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
|
||||||
|
|||||||
@@ -10,16 +10,14 @@ import torch
|
|||||||
from compressed_tensors import CompressionFormat
|
from compressed_tensors import CompressionFormat
|
||||||
from compressed_tensors.quantization import QuantizationStrategy
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
is_cuda,
|
|
||||||
is_fp8_fnuz,
|
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
|
|||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
|
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
||||||
|
|
||||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
fp8_dtype,
|
||||||
|
is_fp8_fnuz,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
scaled_fp8_quant,
|
scaled_fp8_quant,
|
||||||
)
|
)
|
||||||
@@ -71,6 +73,11 @@ from sglang.srt.utils import (
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
|
use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
|
||||||
|
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
|
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
|
||||||
@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Block quant doesn't need to process weights after loading
|
# Block quant doesn't need to process weights after loading
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
# activation_scheme: dynamic
|
# activation_scheme: dynamic
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
weight_scale=layer.weight_scale_inv,
|
weight_scale=layer.weight_scale_inv,
|
||||||
input_scale=None,
|
input_scale=None,
|
||||||
)
|
)
|
||||||
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
|
||||||
layer.weight_scale_inv = torch.nn.Parameter(
|
|
||||||
weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
else:
|
else:
|
||||||
layer.weight = torch.nn.Parameter(
|
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
||||||
layer.weight.data, requires_grad=False
|
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
)
|
layer.weight_scale_inv = torch.nn.Parameter(
|
||||||
layer.weight_scale_inv = torch.nn.Parameter(
|
weight_scale, requires_grad=False
|
||||||
layer.weight_scale_inv.data, requires_grad=False
|
)
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||||
@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
weight_scale = layer.weight_scale
|
weight_scale = layer.weight_scale
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@@ -482,11 +485,7 @@ class Fp8MoEMethod:
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = (
|
params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
|
||||||
torch.uint32
|
|
||||||
if get_bool_env_var("SGLANG_INT4_WEIGHT")
|
|
||||||
else torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
block_n, block_k = (
|
block_n, block_k = (
|
||||||
@@ -511,7 +510,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
|
if _is_hip and use_hip_int4:
|
||||||
# INT4 MoE weight - INT32 packed
|
# INT4 MoE weight - INT32 packed
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@@ -583,9 +582,7 @@ class Fp8MoEMethod:
|
|||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
if (
|
if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
|
||||||
_is_hip
|
|
||||||
): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
|
|
||||||
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
||||||
w13_weight_scale1 = torch.nn.Parameter(
|
w13_weight_scale1 = torch.nn.Parameter(
|
||||||
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||||
@@ -612,7 +609,7 @@ class Fp8MoEMethod:
|
|||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
|
if _is_hip and use_hip_int4:
|
||||||
extra_weight_attrs.update(
|
extra_weight_attrs.update(
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
)
|
)
|
||||||
@@ -644,14 +641,14 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
|
if _is_hip and use_hip_int4:
|
||||||
self.process_weights_hip_int4(layer)
|
self.process_weights_hip_int4(layer)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Block quant doesn't need to process weights after loading
|
# Block quant doesn't need to process weights after loading
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
# activation_scheme: dynamic
|
# activation_scheme: dynamic
|
||||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=layer.w13_weight,
|
weight=layer.w13_weight,
|
||||||
@@ -675,20 +672,19 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
if get_bool_env_var("SGLANG_AITER_MOE"):
|
if _is_hip and use_aiter_moe:
|
||||||
# Pre-shuffle weights
|
# Pre-shuffle weights
|
||||||
layer.w13_weight.data = shuffle_weight(
|
layer.w13_weight.data = shuffle_weight(
|
||||||
layer.w13_weight.contiguous(), (16, 16)
|
layer.w13_weight.contiguous(), (16, 16)
|
||||||
)
|
)
|
||||||
layer.w2_weight.data = shuffle_weight(
|
layer.w2_weight.data = shuffle_weight(
|
||||||
layer.w2_weight.contiguous(), (16, 16)
|
layer.w2_weight.contiguous(), (16, 16)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
# If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||||
|
|
||||||
@@ -742,7 +738,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
# Normalize the weights and scales
|
# Normalize the weights and scales
|
||||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||||
normalize_e4m3fn_to_e4m3fnuz(
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
@@ -798,7 +794,7 @@ class Fp8MoEMethod:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def process_weights_hip_int4(self, layer: Module):
|
def process_weights_hip_int4(self, layer: Module):
|
||||||
# TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
|
# TODO: and use_aiter_moe: add after triton kernel added
|
||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
||||||
# Weight Permutation
|
# Weight Permutation
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
@@ -845,7 +841,7 @@ class Fp8MoEMethod:
|
|||||||
padding_size, # Avoid circular import
|
padding_size, # Avoid circular import
|
||||||
)
|
)
|
||||||
|
|
||||||
if get_bool_env_var("SGLANG_AITER_MOE"):
|
if use_aiter_moe:
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -856,7 +852,7 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# ROCm (SGLANG_AITER_MOE): using column-wise scaling
|
# ROCm (use_aiter_moe): using column-wise scaling
|
||||||
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||||
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||||
elif get_bool_env_var("SGLANG_MOE_PADDING"):
|
elif get_bool_env_var("SGLANG_MOE_PADDING"):
|
||||||
@@ -908,59 +904,16 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
if get_bool_env_var("SGLANG_INT4_WEIGHT"):
|
ret = self.maybe_apply_hip_fused_experts(
|
||||||
# TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
|
layer,
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
x,
|
||||||
return ck_moe_2stages(
|
topk_weights,
|
||||||
x,
|
topk_ids,
|
||||||
layer.w13_weight,
|
activation,
|
||||||
layer.w2_weight,
|
no_combine,
|
||||||
topk_weights,
|
)
|
||||||
topk_ids,
|
if ret is not None:
|
||||||
QuantType.per_Token,
|
return ret
|
||||||
layer.w13_weight_scale1,
|
|
||||||
layer.w2_weight_scale1,
|
|
||||||
activation=(
|
|
||||||
ActivationType.Silu
|
|
||||||
if activation == "silu"
|
|
||||||
else ActivationType.Gelu
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if get_bool_env_var("SGLANG_AITER_MOE"):
|
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
|
||||||
if self.block_quant:
|
|
||||||
# TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
|
||||||
assert (
|
|
||||||
activation == "silu"
|
|
||||||
), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
|
|
||||||
return asm_moe(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
layer.w13_weight_scale_inv,
|
|
||||||
layer.w2_weight_scale_inv,
|
|
||||||
block_shape=tuple(self.quant_config.weight_block_size),
|
|
||||||
expert_mask=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ck_moe_2stages(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
QuantType.per_Token,
|
|
||||||
layer.w13_weight_scale1,
|
|
||||||
layer.w2_weight_scale1,
|
|
||||||
activation=(
|
|
||||||
ActivationType.Silu
|
|
||||||
if activation == "silu"
|
|
||||||
else ActivationType.Gelu
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -987,6 +940,68 @@ class Fp8MoEMethod:
|
|||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def maybe_apply_hip_fused_experts(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
|
no_combine: bool = False,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
if use_hip_int4:
|
||||||
|
# TODO: add triton kernel and add check use_aiter_moe
|
||||||
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
|
return ck_moe_2stages(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
QuantType.per_Token,
|
||||||
|
layer.w13_weight_scale1,
|
||||||
|
layer.w2_weight_scale1,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_aiter_moe:
|
||||||
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
|
if self.block_quant:
|
||||||
|
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
|
||||||
|
assert (
|
||||||
|
activation == "silu"
|
||||||
|
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
|
||||||
|
return asm_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
layer.w13_weight_scale_inv,
|
||||||
|
layer.w2_weight_scale_inv,
|
||||||
|
block_shape=tuple(self.quant_config.weight_block_size),
|
||||||
|
expert_mask=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ck_moe_2stages(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
QuantType.per_Token,
|
||||||
|
layer.w13_weight_scale1,
|
||||||
|
layer.w2_weight_scale1,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu
|
||||||
|
if activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -34,12 +35,6 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
if _is_hip:
|
|
||||||
fp8_max = 224.0
|
|
||||||
else:
|
|
||||||
fp8_max = torch.finfo(_fp8_type).max
|
|
||||||
fp8_min = -fp8_max
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -54,6 +49,24 @@ if _is_cuda:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def is_fp8_fnuz() -> bool:
|
||||||
|
if _is_hip:
|
||||||
|
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
||||||
|
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if is_fp8_fnuz():
|
||||||
|
fp8_dtype = torch.float8_e4m3fnuz
|
||||||
|
fp8_max = 224.0
|
||||||
|
else:
|
||||||
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
|
fp8_max = torch.finfo(fp8_dtype).max
|
||||||
|
fp8_min = -fp8_max
|
||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
|
|
||||||
def deep_gemm_fp8_fp8_bf16_nt(
|
def deep_gemm_fp8_fp8_bf16_nt(
|
||||||
@@ -198,7 +211,7 @@ def per_token_group_quant_fp8(
|
|||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
M = x.numel() // group_size
|
M = x.numel() // group_size
|
||||||
N = group_size
|
N = group_size
|
||||||
if column_major_scales:
|
if column_major_scales:
|
||||||
@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
if column_major_scales:
|
if column_major_scales:
|
||||||
if scale_tma_aligned:
|
if scale_tma_aligned:
|
||||||
# aligned to 4 * sizeof(float)
|
# aligned to 4 * sizeof(float)
|
||||||
@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
|
|
||||||
def sglang_per_token_quant_fp8(
|
def sglang_per_token_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
dtype: torch.dtype = _fp8_type,
|
dtype: torch.dtype = fp8_dtype,
|
||||||
):
|
):
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
@@ -384,7 +397,7 @@ def static_quant_fp8(
|
|||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
assert x_s.numel() == 1, "only supports per-tensor scale"
|
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
|
||||||
M = x.numel() // x.shape[-1]
|
M = x.numel() // x.shape[-1]
|
||||||
N = x.shape[-1]
|
N = x.shape[-1]
|
||||||
if repeat_scale:
|
if repeat_scale:
|
||||||
@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
|
||||||
|
return _w8a8_block_fp8_matmul
|
||||||
|
|
||||||
|
|
||||||
|
if _is_hip:
|
||||||
|
|
||||||
|
def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
|
||||||
|
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||||
|
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||||
|
# compute units available on the device.
|
||||||
|
num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||||
|
N, META["BLOCK_SIZE_N"]
|
||||||
|
)
|
||||||
|
num_workgroups <= get_device_core_count()
|
||||||
|
|
||||||
|
def select_w8a8_block_fp8_matmul_kernel(M, N, META):
|
||||||
|
if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
|
||||||
|
return _w8a8_block_fp8_matmul_unrolledx4
|
||||||
|
else:
|
||||||
|
return _w8a8_block_fp8_matmul
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul(
|
def w8a8_block_fp8_matmul(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul(
|
|||||||
C_shape = A.shape[:-1] + (N,)
|
C_shape = A.shape[:-1] + (N,)
|
||||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
|
||||||
if configs:
|
|
||||||
# If an optimal configuration map has been found, look up the
|
|
||||||
# optimal config
|
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
|
||||||
else:
|
|
||||||
# Default config
|
|
||||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": block_size[0],
|
|
||||||
"BLOCK_SIZE_K": block_size[1],
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
def grid(META):
|
|
||||||
return (
|
|
||||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
|
||||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
|
||||||
# compute units available on the device.
|
|
||||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
|
||||||
N, config["BLOCK_SIZE_N"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# deepgemm only support bf16
|
# deepgemm only support bf16
|
||||||
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
|
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul(
|
|||||||
else:
|
else:
|
||||||
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||||
else:
|
else:
|
||||||
kernel = (
|
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
if configs:
|
||||||
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
# If an optimal configuration map has been found, look up the
|
||||||
else _w8a8_block_fp8_matmul
|
# optimal config
|
||||||
)
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||||
|
else:
|
||||||
|
# Default config
|
||||||
|
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": block_size[0],
|
||||||
|
"BLOCK_SIZE_K": block_size[1],
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||||
|
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
|
||||||
|
|
||||||
kernel[grid](
|
kernel[grid](
|
||||||
A,
|
A,
|
||||||
@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8(
|
|||||||
and x_s_out.device == x.device
|
and x_s_out.device == x.device
|
||||||
)
|
)
|
||||||
|
|
||||||
x_q = x.new_empty(x.size(), dtype=_fp8_type)
|
x_q = x.new_empty(x.size(), dtype=fp8_dtype)
|
||||||
|
|
||||||
num_head, num_seq, head_size = x.shape
|
num_head, num_seq, head_size = x.shape
|
||||||
BLOCK_SIZE = triton.next_power_of_2(head_size)
|
BLOCK_SIZE = triton.next_power_of_2(head_size)
|
||||||
@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
|
|||||||
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
|
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
|
||||||
|
|
||||||
|
|
||||||
def per_tensor_quant_mla_deep_gemm_masked_fp8(
|
def per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int = 128,
|
group_size: int = 128,
|
||||||
eps: float = 1e-12,
|
eps: float = 1e-12,
|
||||||
dtype: torch.dtype = torch.float8_e4m3fn,
|
dtype: torch.dtype = fp8_dtype,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
This function quantizes input values to float8 values with per-token-group-quantization
|
This function quantizes input values to float8 values with per-token-group-quantization
|
||||||
@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
|
|||||||
"""
|
"""
|
||||||
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
fp8_max = finfo.max
|
|
||||||
if _is_hip:
|
|
||||||
dtype = torch.float8_e4m3fnuz
|
|
||||||
fp8_max = 224.0
|
|
||||||
|
|
||||||
b, m, k = x.shape
|
b, m, k = x.shape
|
||||||
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
|
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
|
||||||
num_tiles_k = k // group_size
|
num_tiles_k = k // group_size
|
||||||
@@ -1043,10 +1062,9 @@ def scaled_fp8_quant(
|
|||||||
"""
|
"""
|
||||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||||
shape = input.shape
|
shape = input.shape
|
||||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
if num_token_padding:
|
if num_token_padding:
|
||||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
|
||||||
|
|
||||||
if scale is None:
|
if scale is None:
|
||||||
# Dynamic scaling
|
# Dynamic scaling
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ except ImportError:
|
|||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
fp8_dtype,
|
||||||
|
fp8_max,
|
||||||
|
is_fp8_fnuz,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
scaled_fp8_quant,
|
scaled_fp8_quant,
|
||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
|
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
||||||
|
|
||||||
|
if _is_hip and use_aiter_moe:
|
||||||
from aiter import gemm_a8w8_blockscale
|
from aiter import gemm_a8w8_blockscale
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
|
|||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
TORCH_DEVICE_IDENTITY = None
|
TORCH_DEVICE_IDENTITY = None
|
||||||
|
|
||||||
_TORCH_VERSION = torch.__version__.split("+")[0]
|
|
||||||
try:
|
|
||||||
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
|
|
||||||
except ValueError:
|
|
||||||
_TORCH_VERSION_TUPLE = (0, 0, 0)
|
|
||||||
|
|
||||||
# The condition to determine if it is on a platform that supports
|
def use_rowwise_torch_scaled_mm():
|
||||||
# torch._scaled_mm rowwise feature.
|
_TORCH_VERSION = torch.__version__.split("+")[0]
|
||||||
# The condition is determined once as the operations
|
try:
|
||||||
# are time consuming.
|
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
|
||||||
USE_ROWWISE_TORCH_SCALED_MM = (
|
except ValueError:
|
||||||
_is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
|
_TORCH_VERSION_TUPLE = (0, 0, 0)
|
||||||
)
|
if _is_hip:
|
||||||
|
# The condition to determine if it is on a platform that supports
|
||||||
|
# torch._scaled_mm rowwise feature.
|
||||||
|
# The condition is determined once as the operations
|
||||||
|
# are time consuming.
|
||||||
|
return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
|
||||||
|
|
||||||
|
|
||||||
def cutlass_fp8_supported():
|
def cutlass_fp8_supported():
|
||||||
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
output = fp8_blockwise_scaled_mm(
|
output = fp8_blockwise_scaled_mm(
|
||||||
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
||||||
)
|
)
|
||||||
elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
|
elif _is_hip and use_aiter_moe:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=False
|
||||||
)
|
)
|
||||||
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
|
|
||||||
|
|
||||||
def input_to_float8(
|
def input_to_float8(
|
||||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
|
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""This function quantizes input values to float8 values with tensor-wise quantization."""
|
"""This function quantizes input values to float8 values with tensor-wise quantization."""
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
min_val, max_val = x.aminmax()
|
min_val, max_val = x.aminmax()
|
||||||
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
|
amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
|
||||||
fp8_max = finfo.max
|
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
dtype = torch.float8_e4m3fnuz
|
dtype = fp8_dtype
|
||||||
fp8_max = 224.0
|
fp_max = fp8_max
|
||||||
scale = fp8_max / amax
|
else:
|
||||||
x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
|
finfo = torch.finfo(dtype)
|
||||||
|
fp_max = finfo.max
|
||||||
|
|
||||||
|
scale = fp_max / amax
|
||||||
|
x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
|
||||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import is_hip
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_fp8_fnuz(cls) -> bool:
|
|
||||||
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
|
||||||
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
|
||||||
|
|
||||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||||
|
|
||||||
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
# We prefer to use separate k_scale and v_scale if present
|
# We prefer to use separate k_scale and v_scale if present
|
||||||
k_scale = layer.k_scale.to("cpu").tolist()
|
k_scale = layer.k_scale.to("cpu").tolist()
|
||||||
v_scale = layer.v_scale.to("cpu").tolist()
|
v_scale = layer.v_scale.to("cpu").tolist()
|
||||||
if _is_hip and self.is_fp8_fnuz():
|
if is_fp8_fnuz():
|
||||||
k_scale *= 2
|
k_scale *= 2
|
||||||
v_scale *= 2
|
v_scale *= 2
|
||||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||||
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
if _is_hip and self.is_fp8_fnuz():
|
if is_fp8_fnuz():
|
||||||
k_scale *= 2
|
k_scale *= 2
|
||||||
v_scale *= 2
|
v_scale *= 2
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,6 @@ if not _is_cuda:
|
|||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
|
||||||
def is_fp8_fnuz() -> bool:
|
|
||||||
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
|
||||||
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
|
||||||
|
|
||||||
|
|
||||||
def is_layer_skipped(
|
def is_layer_skipped(
|
||||||
prefix: str,
|
prefix: str,
|
||||||
ignored_layers: List[str],
|
ignored_layers: List[str],
|
||||||
|
|||||||
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
fp8_dtype,
|
||||||
|
is_fp8_fnuz,
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
input_to_float8,
|
input_to_float8,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip, set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
|
|
||||||
class W8A8Fp8Config(QuantizationConfig):
|
class W8A8Fp8Config(QuantizationConfig):
|
||||||
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|||||||
if self.quantization_config.is_checkpoint_fp8_serialized:
|
if self.quantization_config.is_checkpoint_fp8_serialized:
|
||||||
weight_scale = layer.weight_scale.detach()
|
weight_scale = layer.weight_scale.detach()
|
||||||
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
|
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
|
||||||
if _is_hip:
|
if _is_fp8_fnuz:
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=weight, weight_scale=weight_scale
|
weight=weight, weight_scale=weight_scale
|
||||||
)
|
)
|
||||||
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight, layer.weight.shape[-1]
|
layer.weight, layer.weight.shape[-1]
|
||||||
)
|
)
|
||||||
weight_scale = weight_scale.t().contiguous()
|
weight_scale = weight_scale.t().contiguous()
|
||||||
if _is_hip:
|
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=weight, weight_scale=weight_scale
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# if cutlass not supported, we fall back to use torch._scaled_mm
|
# if cutlass not supported, we fall back to use torch._scaled_mm
|
||||||
# which requires per tensor quantization on weight
|
# which requires per tensor quantization on weight
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
|
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
|
||||||
|
|
||||||
# Update the layer with the new values.
|
# Update the layer with the new values.
|
||||||
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
|
|||||||
):
|
):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
|
|||||||
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import ReplicatedLinear
|
from sglang.srt.layers.linear import ReplicatedLinear
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
|
||||||
block_quant_to_tensor_quant,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.quantization.int8_utils import (
|
|
||||||
block_dequant as int8_block_dequant,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||||
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
|
from sglang.srt.utils import BumpAllocator, add_prefix
|
||||||
|
|
||||||
_is_hip = is_hip()
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
from sgl_kernel import awq_dequantize
|
|
||||||
else:
|
|
||||||
from vllm._custom_ops import awq_dequantize
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
|
||||||
per_tensor_quant_mla_fp8,
|
per_tensor_quant_mla_fp8,
|
||||||
|
per_token_group_quant_mla_deep_gemm_masked_fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
block_quant_to_tensor_quant,
|
block_quant_to_tensor_quant,
|
||||||
@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
if self.use_deep_gemm_bmm:
|
if self.use_deep_gemm_bmm:
|
||||||
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
|
||||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
q_nope_out = q_nope.new_empty(
|
q_nope_out = q_nope.new_empty(
|
||||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||||
@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
if self.use_deep_gemm_bmm:
|
if self.use_deep_gemm_bmm:
|
||||||
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
|
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
|
||||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
attn_output.transpose(0, 1)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
attn_bmm_output = attn_output.new_empty(
|
attn_bmm_output = attn_output.new_empty(
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import torch
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
|
||||||
per_tensor_quant_mla_fp8,
|
per_tensor_quant_mla_fp8,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
|
per_token_group_quant_mla_deep_gemm_masked_fp8,
|
||||||
static_quant_fp8,
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
|
|||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
|
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
|
||||||
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
|
out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||||
x, group_size
|
x, group_size
|
||||||
)
|
)
|
||||||
out = out[:, :num_tokens, :]
|
out = out[:, :num_tokens, :]
|
||||||
|
|||||||
Reference in New Issue
Block a user