From c5210dfa3802dbe08a8de9e860cea0c932307c9d Mon Sep 17 00:00:00 2001
From: HAI
Date: Mon, 30 Dec 2024 05:31:12 -0800
Subject: [PATCH] AMD DeepSeek_V3 FP8 Numerical fix (#2667)

---
 python/sglang/srt/models/deepseek_v2.py | 41 ++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c56430ce0..a9c0b59ce 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -46,6 +46,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     input_to_float8,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -55,7 +56,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, is_hip
+
+is_hip_ = is_hip()
 
 if is_flashinfer_available():
     from flashinfer import bmm_fp8
@@ -573,7 +576,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fn:
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = input_to_float8(
                 q_nope.transpose(0, 1), torch.float8_e4m3fn
             )
@@ -598,7 +607,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fn:
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = input_to_float8(
                 attn_output.transpose(0, 1), torch.float8_e4m3fn
             )
@@ -940,15 +955,25 @@ class DeepseekV2ForCausalLM(nn.Module):
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
-            if (
-                hasattr(self.quant_config, "weight_block_size")
-                and w.dtype == torch.float8_e4m3fn
+            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
             ):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if is_hip_:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
                     w, scale = block_quant_to_tensor_quant(
-                        w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                        weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
             w_kc, w_vc = w.unflatten(
@@ -961,6 +986,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                if is_hip_:
+                    self_attn.w_scale *= 2.0
 
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
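Background on the numerical fix (illustrative sketch, not part of the patch): torch.float8_e4m3fnuz, the FP8 variant used on AMD GPUs, has an exponent bias of 8 instead of e4m3fn's 7, so the same bit pattern decodes to half the e4m3fn value, and the 0x80 pattern encodes NaN rather than -0. Converting e4m3fn weights therefore means clearing 0x80 bytes and doubling the scale, which is also the reason the patch multiplies self_attn.w_scale by 2.0 on HIP. The sketch below assumes normalize_e4m3fn_to_e4m3fnuz follows this common scheme; the authoritative implementation lives in sglang/srt/layers/quantization/fp8_utils.py, and the function name suffixed with _sketch here is hypothetical.

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Hypothetical reimplementation for illustration only.
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 encodes -0 in e4m3fn but NaN in e4m3fnuz; clear it before reinterpreting.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits represent half the value in e4m3fnuz, so double the scales
    # to keep the dequantized result unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale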