diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2f06a5534..bb80e2da2 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -161,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
@@ -255,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
         # TODO, support head_size 192
         self.attn = RadixAttention(
             self.num_local_heads,
@@ -283,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
@@ -419,6 +424,7 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         self.w_kc = None
         self.w_vc = None
+        self.w_scale = None
 
     def forward(
         self,
@@ -439,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
                 -1, self.num_local_heads, self.qk_head_dim
             )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope_out = q_input[..., : self.kv_lora_rank]
-        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
 
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
@@ -455,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         attn_output = self.attn(q_input, k_input, v_input, input_metadata)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc,
-            out=attn_bmm_output.transpose(0, 1),
-        )
-        attn_output = attn_bmm_output.flatten(1, 2)
+
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         output, _ = self.o_proj(attn_output)
 
         return output
@@ -717,8 +737,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.contiguous()
-                self_attn.w_vc = w_vc.transpose(1, 2).contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                 del self_attn.kv_b_proj