feat: fix fp8 for MLA and support bmm fp8 for DeepSeek V2 (#1285)
Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from flashinfer import bmm_fp8
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
@@ -161,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
|||||||
return 0.1 * mscale * math.log(scale) + 1.0
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def input_to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
min_val, max_val = x.aminmax()
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax
|
||||||
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Attention(nn.Module):
|
class DeepseekV2Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -255,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
self.scaling = self.scaling * mscale * mscale
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
|
||||||
# self.attn = Attention(self.num_heads,
|
|
||||||
# self.qk_head_dim,
|
|
||||||
# self.scaling,
|
|
||||||
# num_kv_heads=self.num_heads)
|
|
||||||
|
|
||||||
# TODO, support head_size 192
|
# TODO, support head_size 192
|
||||||
self.attn = RadixAttention(
|
self.attn = RadixAttention(
|
||||||
self.num_local_heads,
|
self.num_local_heads,
|
||||||
@@ -283,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
q = self.q_proj(hidden_states)[0].view(
|
q = self.q_proj(hidden_states)[0].view(
|
||||||
-1, self.num_local_heads, self.qk_head_dim
|
-1, self.num_local_heads, self.qk_head_dim
|
||||||
)
|
)
|
||||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
latent_cache = latent_cache.unsqueeze(1)
|
latent_cache = latent_cache.unsqueeze(1)
|
||||||
@@ -419,6 +424,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
self.w_kc = None
|
self.w_kc = None
|
||||||
self.w_vc = None
|
self.w_vc = None
|
||||||
|
self.w_scale = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -439,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
-1, self.num_local_heads, self.qk_head_dim
|
-1, self.num_local_heads, self.qk_head_dim
|
||||||
)
|
)
|
||||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
q_nope_out = q_input[..., : self.kv_lora_rank]
|
|
||||||
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
|
if self.w_kc.dtype == torch.float8_e4m3fn:
|
||||||
|
q_nope_val, q_nope_scale = input_to_float8(
|
||||||
|
q_nope.transpose(0, 1), torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
q_nope_out = bmm_fp8(
|
||||||
|
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||||
|
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||||
|
|
||||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
v_input = latent_cache[..., : self.kv_lora_rank]
|
v_input = latent_cache[..., : self.kv_lora_rank]
|
||||||
@@ -455,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
attn_bmm_output = attn_output.new_empty(
|
|
||||||
q_len, self.num_local_heads, self.v_head_dim
|
|
||||||
)
|
|
||||||
torch.bmm(
|
|
||||||
attn_output.transpose(0, 1),
|
|
||||||
self.w_vc,
|
|
||||||
out=attn_bmm_output.transpose(0, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_bmm_output.flatten(1, 2)
|
if self.w_vc.dtype == torch.float8_e4m3fn:
|
||||||
|
attn_output_val, attn_output_scale = input_to_float8(
|
||||||
|
attn_output.transpose(0, 1), torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
attn_bmm_output = bmm_fp8(
|
||||||
|
attn_output_val,
|
||||||
|
self.w_vc,
|
||||||
|
attn_output_scale,
|
||||||
|
self.w_scale,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
|
||||||
|
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -717,8 +737,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
|
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
self_attn.w_kc = w_kc.contiguous()
|
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||||
self_attn.w_vc = w_vc.transpose(1, 2).contiguous()
|
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||||
|
if hasattr(self_attn.kv_b_proj, "weight_scale"):
|
||||||
|
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||||
del self_attn.kv_b_proj
|
del self_attn.kv_b_proj
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user