diff --git a/python/sglang/math_utils.py b/python/sglang/math_utils.py
new file mode 100644
index 000000000..fd0aa77f7
--- /dev/null
+++ b/python/sglang/math_utils.py
@@ -0,0 +1,8 @@
+# COPIED FROM DeepGEMM
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+# COPIED FROM DeepGEMM
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index e742f19c3..2408af197 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.math_utils import align
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
 
@@ -390,6 +391,66 @@ def block_quant_dequant(
     return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
 
 
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
+
+
 # COPIED FROM DeepGEMM
 def ceil_to_ue8m0(x: torch.Tensor):
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c5c05b016..82a0c1d91 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1935,6 +1936,61 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
             self_attn.use_deep_gemm_bmm = True
 
+        if False:  # TODO (pr-chain)
+            self._weight_requant_ue8m0()
+
+    def _weight_requant_ue8m0(self):
+        weight_block_size = self.quant_config.weight_block_size
+
+        moe_layers = list(
+            range(
+                self.config.first_k_dense_replace,
+                self.config.num_hidden_layers,
+                self.config.moe_layer_freq,
+            )
+        )
+
+        for layer_id in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_id]
+
+            for module in [
+                layer.self_attn.fused_qkv_a_proj_with_mqa,
+                layer.self_attn.q_b_proj,
+                layer.self_attn.kv_b_proj,
+                layer.self_attn.o_proj,
+            ]:
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
+            if layer_id in moe_layers:
+                shared_experts = layer.mlp.shared_experts
+                for module in [
+                    shared_experts.gate_up_proj,
+                    shared_experts.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
+                experts = layer.mlp.experts
+                if isinstance(experts, DeepEPMoE):
+                    for w in [
+                        experts.w13_weight_fp8,
+                        experts.w2_weight_fp8,
+                    ]:
+                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+            else:
+                mlp = layer.mlp
+                assert isinstance(mlp, DeepseekV2MLP)
+                for module in [
+                    mlp.gate_up_proj,
+                    mlp.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn: