diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fa492277a..5252c2411 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -57,6 +57,7 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.fp8_kernel import ( + is_fp8_fnuz, per_tensor_quant_mla_fp8, per_token_group_quant_mla_deep_gemm_masked_fp8, ) @@ -101,6 +102,7 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() +_is_fp8_fnuz = is_fp8_fnuz() if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 @@ -684,7 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_kc = None self.w_vc = None - self.w_scale = None + self.w_scale = 1.0 self.w_scale_k = None self.w_scale_v = None @@ -948,8 +950,8 @@ class DeepseekV2AttentionMLA(nn.Module): expected_m, ) q_nope_out = q_nope_out[:, :expected_m, :] - elif self.w_kc.dtype == torch.float8_e4m3fnuz: - # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + elif _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, @@ -1000,8 +1002,8 @@ class DeepseekV2AttentionMLA(nn.Module): expected_m, ) attn_bmm_output = attn_bmm_output[:, :expected_m, :] - elif self.w_vc.dtype == torch.float8_e4m3fnuz: - # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + elif _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, @@ -1052,8 +1054,8 @@ class DeepseekV2AttentionMLA(nn.Module): latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - if self.w_kc.dtype == torch.float8_e4m3fnuz: - # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + if _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, @@ -1186,8 +1188,8 @@ class DeepseekV2AttentionMLA(nn.Module): attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - if self.w_vc.dtype == torch.float8_e4m3fnuz: - # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + if _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, @@ -1749,46 +1751,56 @@ class DeepseekV2ForCausalLM(nn.Module): torch.float8_e4m3fn, torch.float8_e4m3fnuz, ): - if hasattr(self.quant_config, "weight_block_size"): + if ( + hasattr(self.quant_config, "weight_block_size") + and self.quant_config.weight_block_size is not None + ): weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv - if ( - _is_cuda - and weight_block_size[0] == 128 - and weight_block_size[1] == 128 - and model_dtype == torch.bfloat16 + if ( + _is_cuda + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + and model_dtype == torch.bfloat16 + ): + if _ENABLE_JIT_DEEPGEMM and get_bool_env_var( + "SGL_USE_DEEPGEMM_BMM", "false" ): - if _ENABLE_JIT_DEEPGEMM and get_bool_env_var( - "SGL_USE_DEEPGEMM_BMM", "false" - ): - block_scale = weight_scale - use_deep_gemm_bmm = True - else: - w = block_quant_dequant( - weight, - weight_scale, - weight_block_size, - model_dtype, - ) + block_scale = weight_scale + use_deep_gemm_bmm = True else: - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + model_dtype, ) - self_attn.w_scale = scale + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) self_attn.w_scale = scale