AMD: set weights and scaling numbers properly for block FP8 (#2637)
@@ -272,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, requires_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
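The helper normalize_e4m3fn_to_e4m3fnuz is imported from fp8_utils and is not shown in this hunk. Conceptually it reinterprets the e4m3fn bytes as e4m3fnuz and doubles the scale: the two formats share a bit layout, but e4m3fnuz uses an exponent bias of 8 instead of 7, so the same bits decode to half the value, and the 0x80 pattern (which is -0.0 in e4m3fn) is NaN in e4m3fnuz. A minimal sketch, assuming the helper follows this common implementation:

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Sketch only; the real helper lives in sglang's fp8_utils.
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8)
    # Bit pattern 0x80 (-128 as int8) is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to 0.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value under e4m3fnuz, so double the scales
    # to keep the dequantized weights unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale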
@@ -369,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=None,
                 bias=bias,
             )
 
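With activation_scheme: dynamic, activation scales are produced per token group inside the quantization kernel, so the layer-level input_scale is unused and the call now passes None explicitly. For orientation only, a naive eager-mode equivalent of the block FP8 matmul (an illustrative assumption about the scale layout, not the Triton kernel sglang actually uses):

import torch

def w8a8_block_fp8_matmul_ref(x_q, x_s, w_q, w_s, block_size, out_dtype=torch.bfloat16):
    # x_q: (M, K) fp8 activations, x_s: (M, ceil(K / block_k)) per-token-group scales
    # w_q: (N, K) fp8 weights,     w_s: (ceil(N / block_n), ceil(K / block_k)) block scales
    block_n, block_k = block_size
    M, K = x_q.shape
    N = w_q.shape[0]
    # Dequantize by broadcasting each scale over its block, then do a plain matmul.
    x = x_q.to(torch.float32) * x_s.repeat_interleave(block_k, dim=1)[:, :K]
    w_scale = w_s.repeat_interleave(block_n, dim=0)[:N]
    w_scale = w_scale.repeat_interleave(block_k, dim=1)[:, :K]
    w = w_q.to(torch.float32) * w_scale
    return (x @ w.t()).to(out_dtype)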
@@ -553,6 +566,30 @@ class Fp8MoEMethod:
 
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
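The MoE path applies the same normalization per expert to the fused gate/up projection (w13) and the down projection (w2). A quick, hypothetical sanity check that the conversion preserves dequantized values (halving the decoded weight is exactly compensated by doubling the scale); the import path is assumed and this snippet is not part of the commit:

import torch

# Assumed import path for the helper used in this diff.
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz

w_fn = torch.randn(64, 64).to(torch.float8_e4m3fn)
scale = torch.tensor(0.01)
w_fnuz, scale2, _ = normalize_e4m3fn_to_e4m3fnuz(
    weight=w_fn.clone(), weight_scale=scale, input_scale=None
)
# Dequantized values must match: w_fn * s == w_fnuz * (2 * s).
assert torch.allclose(w_fn.to(torch.float32) * scale, w_fnuz.to(torch.float32) * scale2)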
@@ -22,7 +22,10 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_hip
 
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
 logger = logging.getLogger(__name__)
 
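The module-level fp8_type_ picks the hardware-native FP8 flavor: e4m3fnuz on ROCm (MI300-class GPUs), e4m3fn elsewhere. For reference, the two dtypes differ in range and in how the 0x80 bit pattern is interpreted; the values below come straight from torch.finfo:

import torch

# float8_e4m3fn   (OCP / NVIDIA-style): max 448.0, 0x80 encodes -0.0
# float8_e4m3fnuz (AMD MI300 native):   max 240.0, 0x80 encodes NaN, no -0.0
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0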
@@ -73,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -95,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max
 
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
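On ROCm the usable range is capped at ±224.0, a conservative limit below e4m3fnuz's finfo.max of 240.0, and fp8_min is now derived as -fp8_max so the clamp stays symmetric after the cap. A rough eager-mode equivalent of the per-token-group quantization the Triton kernel performs, assuming the usual scale = amax / fp8_max convention (the kernel body is not shown in this hunk):

import torch

def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
    # Eager-mode sketch (illustrative assumption): one scale per contiguous group
    # of `group_size` elements, scale = amax / fp8_max, values clamped to range.
    fp8_max = torch.finfo(dtype).max
    if dtype == torch.float8_e4m3fnuz:
        fp8_max = 224.0  # conservative ROCm cap used in this patch
    fp8_min = -fp8_max
    x_grouped = x.reshape(-1, group_size).to(torch.float32)
    amax = x_grouped.abs().amax(dim=1, keepdim=True).clamp(min=eps)
    scale = amax / fp8_max
    x_q = (x_grouped / scale).clamp(min=fp8_min, max=fp8_max).to(dtype)
    return x_q.reshape(x.shape), scale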
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
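input_to_float8 gets the same ±224.0 cap on ROCm. It returns the quantized tensor plus the reciprocal of the scale, so dequantization is a single multiply; a small usage sketch (the import path is assumed, and values here are made up):

import torch

# Assumed import path; input_to_float8 is the helper patched above.
from sglang.srt.layers.quantization.fp8_utils import input_to_float8

x = torch.randn(4, 8, dtype=torch.float32)
x_fp8, scale_inv = input_to_float8(x)  # per-tensor quantization
x_deq = x_fp8.to(torch.float32) * scale_inv  # dequantize with the reciprocal scale
print((x - x_deq).abs().max())  # small quantization error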