DeepSeek: enable non-block-quant FP8 quantizations (#6638)
This commit is contained in:
@@ -57,6 +57,7 @@ from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
per_tensor_quant_mla_fp8,
|
||||
per_token_group_quant_mla_deep_gemm_masked_fp8,
|
||||
)
|
||||
@@ -101,6 +102,7 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
@@ -684,7 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
self.w_scale = None
|
||||
self.w_scale = 1.0
|
||||
|
||||
self.w_scale_k = None
|
||||
self.w_scale_v = None
|
||||
@@ -948,8 +950,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
expected_m,
|
||||
)
|
||||
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
elif _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
@@ -1000,8 +1002,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
expected_m,
|
||||
)
|
||||
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
elif _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||
@@ -1052,8 +1054,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
if _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
@@ -1186,8 +1188,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
if _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||
@@ -1749,46 +1751,56 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
):
|
||||
if hasattr(self.quant_config, "weight_block_size"):
|
||||
if (
|
||||
hasattr(self.quant_config, "weight_block_size")
|
||||
and self.quant_config.weight_block_size is not None
|
||||
):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
if weight_block_size is not None:
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w,
|
||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
else:
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
if _is_fp8_fnuz:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w,
|
||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
else:
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||
|
||||
if (
|
||||
_is_cuda
|
||||
and weight_block_size[0] == 128
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
if (
|
||||
_is_cuda
|
||||
and weight_block_size[0] == 128
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
):
|
||||
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
"SGL_USE_DEEPGEMM_BMM", "false"
|
||||
):
|
||||
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
"SGL_USE_DEEPGEMM_BMM", "false"
|
||||
):
|
||||
block_scale = weight_scale
|
||||
use_deep_gemm_bmm = True
|
||||
else:
|
||||
w = block_quant_dequant(
|
||||
weight,
|
||||
weight_scale,
|
||||
weight_block_size,
|
||||
model_dtype,
|
||||
)
|
||||
block_scale = weight_scale
|
||||
use_deep_gemm_bmm = True
|
||||
else:
|
||||
w, scale = block_quant_to_tensor_quant(
|
||||
weight, weight_scale, weight_block_size
|
||||
w = block_quant_dequant(
|
||||
weight,
|
||||
weight_scale,
|
||||
weight_block_size,
|
||||
model_dtype,
|
||||
)
|
||||
self_attn.w_scale = scale
|
||||
else:
|
||||
w, scale = block_quant_to_tensor_quant(
|
||||
weight, weight_scale, weight_block_size
|
||||
)
|
||||
self_attn.w_scale = scale
|
||||
else:
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale
|
||||
if _is_fp8_fnuz:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w,
|
||||
weight_scale=self_attn.kv_b_proj.weight_scale,
|
||||
input_scale=None,
|
||||
)
|
||||
else:
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale
|
||||
|
||||
w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
|
||||
self_attn.w_scale = scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user