diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index dc70c53b3..f33139883 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -441,25 +441,55 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )
 
+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
     weight_dequant_flat = weight_dequant.view((-1, k))
     out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
 
-    out_w = out_w_flat.view(weight.shape)
-    out_s = out_s_flat.view(weight_scale_inv.shape)
-
-    # NOTE copy and modified from DeepGEMM
-    def _transform_scale(sf, mn: int):
-        import deep_gemm.utils.layout
-
-        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-        return sf
-
-    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )
 
     return out_w, out_s
 
 
+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
+
+
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
+
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
+
+
 # COPIED FROM DeepGEMM
 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index e796809a0..31da13318 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_lora_rank,
             self.num_heads * self.qk_head_dim,
             bias=False,
-            quant_config=quant_config,
+            quant_config=self._get_q_b_proj_quant_config(quant_config),
             prefix=add_prefix("q_b_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -3130,6 +3143,13 @@
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
 
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             self._transform_scale_nextn_moe_ue8m0()
@@ -3198,6 +3218,25 @@
                 module.weight, module.weight_scale_inv, weight_block_size
             )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
     # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
     def _transform_scale_nextn_moe_ue8m0(self):
         layer = self.model.decoder
@@ -3235,6 +3274,8 @@
         else:
             raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             weights = self._quant_nextn_moe_to_fp8_ue8m0(
                 weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     # TODO avoid code dup
     def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
         weights_dict = dict(weights)