@@ -43,6 +43,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -186,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
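
The hunk above removes the module-local `input_to_float8` helper in favor of the shared copy now imported from `sglang.srt.layers.quantization.fp8_utils` (see the first hunk). For orientation only (not part of the patch): the helper returns the fp8 tensor together with its dequantization scale, so a round trip looks roughly like the sketch below, assuming the relocated function keeps the signature shown in the removed code.

import torch

from sglang.srt.layers.quantization.fp8_utils import input_to_float8

x = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
x_fp8, scale = input_to_float8(x)          # scale = amax / finfo.max, i.e. the dequant scale
x_restored = x_fp8.to(torch.bfloat16) * scale
print((x - x_restored).abs().max())        # small, bounded by fp8 e4m3 resolution
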
@@ -869,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
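
For context on the new skip logic: DeepSeek-V3 checkpoints ship extra next-token-prediction (nextn/MTP) layers whose indices in `model.layers.<i>` start at `num_hidden_layers`, and since nextn is not supported yet those weights are simply dropped during loading. A standalone restatement of the rule (illustration only; the helper name is invented):

def is_nextn_weight(name: str, num_hidden_layers: int) -> bool:
    # True if the weight belongs to an extra next-token-prediction layer.
    if not name.startswith("model.layers"):
        return False
    parts = name.split(".")
    return len(parts) >= 3 and int(parts[2]) >= num_hidden_layers

# e.g. with 61 decoder layers (as in DeepSeek-V3), anything under model.layers.61.* is skipped:
assert is_nextn_weight("model.layers.61.eh_proj.weight", 61)
assert not is_nextn_weight("model.layers.60.mlp.gate.weight", 61)
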
@@ -933,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
+                    # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                    # This may affect the accuracy of fp8 model.
+                    if (
+                        hasattr(self.quant_config, "weight_block_size")
+                        and w.dtype == torch.float8_e4m3fn
+                    ):
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            w, scale = block_quant_to_tensor_quant(
+                                w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                            )
+                            self_attn.w_scale = scale
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
 
 
-EntryClass = DeepseekV2ForCausalLM
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
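
On the NOTE in the last hunk: the fp8 checkpoint stores `kv_b_proj` with block-wise scales (`weight_scale_inv`), but the `bmm_fp8` path used for MLA accepts only a single per-tensor scale, so the weight is dequantized and requantized once at load time via `block_quant_to_tensor_quant`, at some cost in accuracy. A conceptual sketch of that conversion, under the assumption that it mirrors the library routine (this is not the sglang implementation):

import torch

def block_to_tensor_quant_sketch(w_q, scale_inv, block_size):
    bn, bk = block_size                          # e.g. [128, 128]
    n, k = w_q.shape
    # 1) dequantize: broadcast each block's scale over its block
    scales = scale_inv.repeat_interleave(bn, dim=0)[:n]
    scales = scales.repeat_interleave(bk, dim=1)[:, :k]
    w = w_q.to(torch.float32) * scales
    # 2) requantize with one per-tensor scale (same recipe as input_to_float8)
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = w.abs().amax().clamp(min=1e-12)
    w_q_new = (w * finfo.max / amax).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_q_new, amax / finfo.max
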