Change bf16 to fp8 for some gemms in attention for DeepSeek ckpt v2 (#11805)
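
The change is gated behind the SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN environment variable: when it is set, the bf16 q_b_proj weight of the NVFP4 DeepSeek checkpoint (ckpt v2) is re-quantized at load time to block-wise fp8 ([128, 128] blocks with ue8m0 scales) and that attention GEMM runs through the fp8 path instead of bf16. A minimal sketch of turning it on (the launch command is illustrative, not part of this diff):

    import os

    # Opt-in flag introduced by this PR; q_b_proj is re-quantized to fp8 when truthy.
    os.environ["SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"] = "1"

    # Then launch as usual, e.g.:
    #   python -m sglang.launch_server --model-path <deepseek-fp4-ckpt-v2>
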
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader

 try:
     from vllm import _custom_ops as ops
@@ -441,25 +441,55 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )

+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
     weight_dequant_flat = weight_dequant.view((-1, k))
     out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)

-    out_w = out_w_flat.view(weight.shape)
-    out_s = out_s_flat.view(weight_scale_inv.shape)
-
-    # NOTE copy and modified from DeepGEMM
-    def _transform_scale(sf, mn: int):
-        import deep_gemm.utils.layout
-
-        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-        return sf
-
-    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )

     return out_w, out_s


+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
+
+
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
+
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
+
+
 # COPIED FROM DeepGEMM
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
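
For context, per_block_cast_to_fp8 (copied from DeepGEMM, cut off by the hunk boundary above) quantizes a 2-D bf16 matrix in 128x128 tiles, and the ue8m0 path keeps each tile's scale as a power of two so DeepGEMM can pack it into an 8-bit exponent. A minimal self-contained sketch of that idea, not the DeepGEMM implementation (padding and rounding details are assumptions; requires a PyTorch build with float8 support):

    import torch

    def per_block_fp8_ue8m0_sketch(x: torch.Tensor, block: int = 128):
        """Quantize a 2-D tensor into block x block fp8 tiles with power-of-two scales."""
        assert x.dim() == 2
        n, k = x.shape
        pn, pk = -(-n // block) * block, -(-k // block) * block  # round up to tile multiples
        padded = torch.zeros((pn, pk), dtype=x.dtype, device=x.device)
        padded[:n, :k] = x
        tiles = padded.view(pn // block, block, pk // block, block)
        amax = tiles.abs().amax(dim=(1, 3), keepdim=True).float().clamp(min=1e-4)
        # ue8m0-style scale: round up to a power of two (8-bit exponent, no mantissa)
        scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))  # 448 = fp8 e4m3 max
        q = (tiles / scale).to(torch.float8_e4m3fn)
        out_w = q.view(pn, pk)[:n, :k].contiguous()
        out_s = scale.view(pn // block, pk // block)  # same layout as weight_scale_inv
        return out_w, out_s
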
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+

 class DeepseekV2DecoderLayer(nn.Module):

@@ -3130,6 +3143,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)

+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             self._transform_scale_nextn_moe_ue8m0()

@@ -3198,6 +3218,25 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )

+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
     # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
     def _transform_scale_nextn_moe_ue8m0(self):
         layer = self.model.decoder
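
The mn passed to transform_scale_ue8m0_inplace is the weight's output dimension: under [128, 128] block quantization weight_scale_inv holds one scale per 128x128 tile, and the transform first expands it to one scale row per output row (the arange(mn) // 128 index_select) before packing it into DeepGEMM's MN-major ue8m0 layout. A small shape sketch with assumed DeepSeek-V3-like q_b_proj dimensions (the concrete numbers are illustrative):

    import torch

    # Assumed (illustrative) q_b_proj shape: n = num_heads * qk_head_dim, k = q_lora_rank.
    n, k = 128 * 192, 1536
    weight = torch.empty(n, k, dtype=torch.float8_e4m3fn)
    scale_inv = torch.empty((n + 127) // 128, (k + 127) // 128)  # one scale per 128x128 tile

    expanded = scale_inv.index_select(-2, torch.arange(n) // 128)  # one scale row per output row
    print(scale_inv.shape, expanded.shape)  # torch.Size([192, 12]) torch.Size([24576, 12])
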
@@ -3235,6 +3274,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")

+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             weights = self._quant_nextn_moe_to_fp8_ue8m0(
                 weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@ class DeepseekV2ForCausalLM(nn.Module):

         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     # TODO avoid code dup
     def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
         weights_dict = dict(weights)
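
Net effect of _quant_attn_to_fp8_ue8m0 on the checkpoint stream: each targeted q_b_proj entry is replaced by an fp8 tensor and gains a companion weight_scale_inv entry, so the fp8 linear method can load it as if the checkpoint had been serialized in fp8. A toy illustration of that rewrite (shapes and tensors are stand-ins, not the real quantization):

    import torch

    weights = [("model.layers.0.self_attn.q_b_proj.weight", torch.randn(512, 256, dtype=torch.bfloat16))]
    weights_dict = dict(weights)

    name = "model.layers.0.self_attn.q_b_proj"
    w = weights_dict[f"{name}.weight"]
    # Stand-ins for quant_weight_ue8m0's outputs (fp8 blocks + [128, 128] block scales).
    weights_dict[f"{name}.weight"] = w.to(torch.float8_e4m3fn)
    weights_dict[f"{name}.weight_scale_inv"] = torch.ones(512 // 128, 256 // 128)

    weights = list(weights_dict.items())  # now carries the extra weight_scale_inv entry
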