Re-quantize DeepSeek model weights to support DeepGEMM new input format (#7156)
This commit is contained in:
8
python/sglang/math_utils.py
Normal file
8
python/sglang/math_utils.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# COPIED FROM DeepGEMM
|
||||
def align(x: int, y: int) -> int:
|
||||
return ceil_div(x, y) * y
|
||||
|
||||
|
||||
# COPIED FROM DeepGEMM
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.math_utils import align
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
|
||||
@@ -390,6 +391,66 @@ def block_quant_dequant(
|
||||
return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
|
||||
|
||||
|
||||
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
assert isinstance(weight_scale_inv, torch.nn.Parameter)
|
||||
weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
|
||||
weight, weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
|
||||
def _requant_weight_ue8m0(
|
||||
weight: torch.Tensor,
|
||||
weight_scale_inv: torch.Tensor,
|
||||
weight_block_size: List[int],
|
||||
):
|
||||
assert weight_block_size == [128, 128]
|
||||
|
||||
*_, n, k = weight.shape
|
||||
|
||||
weight_dequant = block_quant_dequant(
|
||||
weight,
|
||||
weight_scale_inv,
|
||||
weight_block_size,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
weight_dequant_flat = weight_dequant.view((-1, k))
|
||||
out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
|
||||
|
||||
out_w = out_w_flat.view(weight.shape)
|
||||
out_s = out_s_flat.view(weight_scale_inv.shape)
|
||||
|
||||
# NOTE copy and modified from DeepGEMM
|
||||
def _transform_scale(sf, mn: int):
|
||||
import deep_gemm.utils.layout
|
||||
|
||||
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
|
||||
sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
|
||||
return sf
|
||||
|
||||
out_s = _transform_scale(out_s, mn=out_w.shape[-2])
|
||||
|
||||
return out_w, out_s
|
||||
|
||||
|
||||
# COPIED FROM DeepGEMM
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
sf = ceil_to_ue8m0(x_amax / 448.0)
|
||||
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
# COPIED FROM DeepGEMM
|
||||
def ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
block_quant_to_tensor_quant,
|
||||
channel_quant_to_tensor_quant,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requant_weight_ue8m0_inplace,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_utils import (
|
||||
block_dequant as int8_block_dequant,
|
||||
@@ -1935,6 +1936,61 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
if False: # TODO (pr-chain)
|
||||
self._weight_requant_ue8m0()
|
||||
|
||||
def _weight_requant_ue8m0(self):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
|
||||
moe_layers = list(
|
||||
range(
|
||||
self.config.first_k_dense_replace,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.moe_layer_freq,
|
||||
)
|
||||
)
|
||||
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
layer = self.model.layers[layer_id]
|
||||
|
||||
for module in [
|
||||
layer.self_attn.fused_qkv_a_proj_with_mqa,
|
||||
layer.self_attn.q_b_proj,
|
||||
layer.self_attn.kv_b_proj,
|
||||
layer.self_attn.o_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
if layer_id in moe_layers:
|
||||
shared_experts = layer.mlp.shared_experts
|
||||
for module in [
|
||||
shared_experts.gate_up_proj,
|
||||
shared_experts.down_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
experts = layer.mlp.experts
|
||||
if isinstance(experts, DeepEPMoE):
|
||||
for w in [
|
||||
experts.w13_weight_fp8,
|
||||
experts.w2_weight_fp8,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
|
||||
else:
|
||||
mlp = layer.mlp
|
||||
assert isinstance(mlp, DeepseekV2MLP)
|
||||
for module in [
|
||||
mlp.gate_up_proj,
|
||||
mlp.down_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||
|
||||
if is_nextn:
|
||||
|
||||
Reference in New Issue
Block a user