Clean up fp8 support (#4230)
This commit is contained in:
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
@@ -624,56 +624,9 @@ class Fp8MoEMethod:
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
||||
# Weight Permutation
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w2_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
|
||||
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
|
||||
# We won't do requant each expert's fp8 weight (not direct available),
|
||||
# instead we adjust half of INT4 w13_weight_scale1 numbers
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.num_experts):
|
||||
start = 0
|
||||
max_w13_scale_fp8 = max_w13_scales[expert_id]
|
||||
for shard_id in range(2):
|
||||
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
|
||||
int4_rescale = (
|
||||
layer.w13_weight_scale[expert_id][shard_id]
|
||||
/ max_w13_scale_fp8
|
||||
)
|
||||
layer.w13_weight_scale1[expert_id][
|
||||
start : start + shard_size
|
||||
] *= int4_rescale
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
|
||||
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
|
||||
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
|
||||
for expert_id in range(layer.num_experts):
|
||||
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
|
||||
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
|
||||
self.process_weights_hip_int4(layer)
|
||||
return
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
padding_size, # Avoid circular import
|
||||
)
|
||||
|
||||
# Block quant doesn't need to process weights after loading
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
@@ -710,6 +663,7 @@ class Fp8MoEMethod:
|
||||
layer.w2_weight.contiguous(), (16, 16)
|
||||
)
|
||||
return
|
||||
|
||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
||||
@@ -736,32 +690,7 @@ class Fp8MoEMethod:
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
if is_hip_:
|
||||
if get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w2_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
# ROCm (CK_MOE): using column-wise scaling
|
||||
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||
elif get_bool_env_var("MOE_PADDING"):
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
self.process_weights_hip_scale_padding(layer)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
@@ -843,34 +772,84 @@ class Fp8MoEMethod:
|
||||
)
|
||||
|
||||
if is_hip_:
|
||||
if get_bool_env_var("CK_MOE"):
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w13_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
permute_weight(layer.w2_weight.data),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
# ROCm (CK_MOE): using column-wise scaling
|
||||
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||
elif get_bool_env_var("MOE_PADDING"):
|
||||
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
self.process_weights_hip_scale_padding(layer)
|
||||
return
|
||||
|
||||
def process_weights_hip_int4(self, layer: Module):
    """Post-process INT4-weight / FP8-compute MoE parameters after loading.

    Steps, all performed in place on ``layer``:

    1. Replace ``w13_weight`` and ``w2_weight`` with new non-trainable
       parameters built from ``permute_weight`` of the current data.
    2. Collapse the two per-shard entries of ``w13_weight_scale`` of each
       expert into their maximum, compensating inside the matching slice
       of ``w13_weight_scale1``.
    3. Multiply the per-expert FP8 scales into ``w13_weight_scale1`` and
       ``w2_weight_scale1``.

    Args:
        layer: MoE layer that owns the attributes read and written here
            (``w13_weight``, ``w2_weight``, ``w13_weight_scale``,
            ``w13_weight_scale1``, ``w2_weight_scale``, ``w2_weight_scale1``,
            ``intermediate_size_per_partition``, ``num_experts``).
    """
    # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
    # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
    # Weight Permutation. The cache is emptied after each replacement, right
    # after the previous parameter has been dropped from the layer.
    for weight_name in ("w13_weight", "w2_weight"):
        setattr(
            layer,
            weight_name,
            torch.nn.Parameter(
                permute_weight(getattr(layer, weight_name).data),
                requires_grad=False,
            ),
        )
        torch.cuda.empty_cache()

    # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
    # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
    # We won't do requant each expert's fp8 weight (not direct available),
    # instead we adjust half of INT4 w13_weight_scale1 numbers
    assert layer.w13_weight_scale is not None
    shard_size = layer.intermediate_size_per_partition
    # NOTE(review): max over dim=1 plus the [expert][shard] indexing below
    # suggests w13_weight_scale is (num_experts, 2) - confirm against caller.
    max_w13_scales = layer.w13_weight_scale.max(dim=1).values
    for expert_idx in range(layer.num_experts):
        expert_max_scale = max_w13_scales[expert_idx]
        for shard_idx in range(2):
            shard_scale = layer.w13_weight_scale[expert_idx][shard_idx]
            if shard_scale == expert_max_scale:
                # This shard already uses the expert-wide scale.
                continue
            # Shards are laid out back to back, shard_size entries each.
            shard_start = shard_idx * shard_size
            layer.w13_weight_scale1[expert_idx][
                shard_start : shard_start + shard_size
            ] *= shard_scale / expert_max_scale

    layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

    # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
    # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
    for expert_idx in range(layer.num_experts):
        layer.w13_weight_scale1[expert_idx] *= max_w13_scales[expert_idx]
        layer.w2_weight_scale1[expert_idx] *= layer.w2_weight_scale[expert_idx]
|
||||
|
||||
def process_weights_hip_scale_padding(
    self, layer: Module, padding_size: "int | None" = None
):
    """Apply the ROCm post-load weight step selected by environment flags.

    * ``CK_MOE`` set: ``w13_weight`` and ``w2_weight`` are replaced by
      non-trainable parameters built from ``permute_weight`` of their data,
      and ``w13_weight_scale`` / ``w2_weight_scale`` (with a trailing
      dimension added) are multiplied into ``w13_weight_scale1`` /
      ``w2_weight_scale1`` in place.
    * otherwise, ``MOE_PADDING`` set: both weights are replaced by copies
      zero-padded by ``padding_size`` at the end of their last dimension.
    * neither set: ``layer`` is left untouched.

    Args:
        layer: MoE layer whose weights and scales are rewritten in place.
        padding_size: Number of zeros appended to the last dimension in the
            ``MOE_PADDING`` branch. When omitted (``None``), the
            ``padding_size`` constant from
            ``sglang.srt.layers.moe.fused_moe_triton.fused_moe`` is used.
            The parameter used to be required even though callers pass only
            ``layer``, and any supplied value was overwritten by the import.
    """
    if padding_size is None:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
            padding_size,  # Avoid circular import
        )

    if get_bool_env_var("CK_MOE"):
        layer.w13_weight = torch.nn.Parameter(
            permute_weight(layer.w13_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            permute_weight(layer.w2_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        # ROCm (CK_MOE): using column-wise scaling
        layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
        layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
    elif get_bool_env_var("MOE_PADDING"):
        # If ROCm, apply weight padding (min. Mem channel contention) only if set
        layer.w13_weight = torch.nn.Parameter(
            F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
Reference in New Issue
Block a user