Auto set draft model path for MTP (#5793)
@@ -47,6 +47,7 @@ class ModelConfig:
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
    ) -> None:

        self.model_path = model_path
@@ -85,6 +86,12 @@ class ModelConfig:
        else:
            enable_multimodal = True

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding

@@ -71,6 +71,7 @@ class TpModelWorker:
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            is_draft_model=is_draft_worker,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,

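Taken together, these two hunks mean a draft TpModelWorker built from a DeepSeek-V3 checkpoint is transparently loaded as the single-layer NextN (MTP) draft model. A minimal sketch of the override, not part of the diff, with SimpleNamespace standing in for the real hf_config:

from types import SimpleNamespace

def resolve_architecture(hf_config, is_draft_model: bool) -> str:
    # Mirrors the ModelConfig hunk above: a draft worker on DeepSeek-V3 is
    # re-targeted to the NextN (MTP) draft architecture.
    if is_draft_model and hf_config.architectures[0] == "DeepseekV3ForCausalLM":
        hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
    return hf_config.architectures[0]

cfg = SimpleNamespace(architectures=["DeepseekV3ForCausalLM"])
assert resolve_architecture(cfg, is_draft_model=True) == "DeepseekV3ForCausalLMNextN"
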
@@ -692,9 +692,14 @@ class ModelRunner:
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if self.use_mla_backend:
            num_layers = (
                self.model_config.num_hidden_layers
                if not self.is_draft_worker
                else self.model_config.hf_config.num_nextn_predict_layers
            )
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
@@ -809,7 +814,11 @@ class ModelRunner:
            dtype=self.kv_cache_dtype,
            kv_lora_rank=self.model_config.kv_lora_rank,
            qk_rope_head_dim=self.model_config.qk_rope_head_dim,
            layer_num=self.model_config.num_hidden_layers,
            layer_num=(
                self.model_config.num_hidden_layers
                if not self.is_draft_worker
                else self.model_config.hf_config.num_nextn_predict_layers
            ),
            device=self.device,
            enable_memory_saver=self.server_args.enable_memory_saver,
        )

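Both hunks size the MLA KV pool of the draft worker by the NextN layer count instead of the full model depth. A rough per-token arithmetic sketch, not part of the diff, using illustrative DeepSeek-V3-like values (kv_lora_rank=512, qk_rope_head_dim=64, 61 hidden layers, 1 NextN layer, 1-byte fp8 cache elements):

kv_lora_rank = 512
qk_rope_head_dim = 64
num_hidden_layers = 61          # target model depth
num_nextn_predict_layers = 1    # draft (MTP) depth
element_size = 1                # e.g. fp8 KV cache

def mla_cell_size(num_layers: int) -> int:
    # Matches the cell_size formula above: one latent KV vector per layer per token.
    return (kv_lora_rank + qk_rope_head_dim) * num_layers * element_size

assert mla_cell_size(num_hidden_layers) == 35136       # target worker, bytes per token
assert mla_cell_size(num_nextn_predict_layers) == 576  # draft worker, bytes per token
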
@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        if hasattr(self.config, "num_nextn_predict_layers"):
            num_nextn_layers = self.config.num_nextn_predict_layers
            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
            assert num_nextn_layers == self.config.num_hidden_layers
        else:
            raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion > 0:
            logger.info(
                f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
            )
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []
            for suffix in suffix_list:
                shared_expert_weight_name = (
                    f"model.layers.0.mlp.shared_experts.{suffix}"
                )
                for num_repeat in range(self.n_share_experts_fusion):
                    weights_list.append(
                        (
                            f"model.layers.0."
                            f"mlp.experts."
                            f"{self.config.n_routed_experts + num_repeat}"
                            f".{suffix}",
                            weights_dict[shared_expert_weight_name],
                        )
                    )
                names_to_remove += [shared_expert_weight_name]
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

        nextn_layer_prefix = "model.layers.0"
        nextn_spec_weight_names = [
            "shared_head.norm",
            "eh_proj",
            "enorm",
            "hnorm",
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not name.startswith(nextn_layer_prefix):
                continue

            # Use shared head and embed weights from target model
            if "shared_head.head" in name or "embed_tokens" in name:
                continue

            is_decoder = True
            # For nextn specific weights
            for weight_name in nextn_spec_weight_names:
                if weight_name in name:
                    name = name.replace(nextn_layer_prefix, "model")
                    is_decoder = False
                    break
            # For decoder layer weights
            if is_decoder:
                name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Handle fused_qkv_a_proj
                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):

                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )

                            param_name = name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)

        self_attn = self.model.decoder.self_attn
        if hasattr(self_attn.kv_b_proj, "qweight"):
            # AWQ compatible
            if _is_cuda:
                w = awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                ).T
            else:
                w = awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                    0,
                    0,
                    0,
                ).T
        else:
            w = self_attn.kv_b_proj.weight
        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
        # This may affect the accuracy of the fp8 model.
        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                if _is_hip:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                        input_scale=None,
                    )
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv

                w, scale = block_quant_to_tensor_quant(
                    weight, weight_scale, weight_block_size
                )
                self_attn.w_scale = scale
        if w.dtype == torch.int8:
            if hasattr(self.quant_config, "weight_block_size"):
                # block-wise int8 needs it
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
                        torch.bfloat16
                    )
            else:
                # channel-wise int8 needs it
                assert hasattr(self_attn.kv_b_proj, "weight_scale")
                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                    torch.bfloat16
                )
        w_kc, w_vc = w.unflatten(
            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
            if _is_hip:
                self_attn.w_scale *= 2.0
        super().load_weights(weights, is_nextn=True)


EntryClass = [DeepseekV3ForCausalLMNextN]

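Per the hunk header (-263,+7), the NextN-specific weight remapping and kv_b_proj post-processing above are deleted from this file and folded into the shared DeepseekV2 loader. A hedged reconstruction of the slimmed-down method, inferred from the super() call shown above rather than quoted from the new file:

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Remapping and post-processing now live in DeepseekV2ForCausalLM
        # load_weights / post_load_weights, selected via is_nextn=True.
        super().load_weights(weights, is_nextn=True)
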
@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self):
    def post_load_weights(self, is_nextn=False):

        # Perform post-processing after loading weights
        for layer_id in range(self.config.num_hidden_layers):
            self_attn = self.model.layers[layer_id].self_attn
        layer_ids = (
            range(self.config.num_hidden_layers)
            if not is_nextn
            else [self.config.num_hidden_layers]
        )
        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []
            for moe_layer in tqdm(

            moe_layers = (
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                ),
                )
                if not is_nextn
                else [nextn_layer_id]
            )

            for moe_layer in tqdm(
                moe_layers,
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
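The loop above clones shared-expert tensors by re-registering them under extra routed-expert indices, so the fused MoE kernel can treat them uniformly; for the nextn case it visits only the single draft layer. A small sketch of the renaming, not part of the diff, using DeepSeek-V3-like numbers (256 routed experts, one fused shared expert) purely as an example:

def clone_shared_expert_names(layer: int, n_routed_experts: int, n_share: int, suffix: str):
    # e.g. "model.layers.3.mlp.shared_experts.down_proj.weight"
    #   -> "model.layers.3.mlp.experts.256.down_proj.weight"
    src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
    return [
        (f"model.layers.{layer}.mlp.experts.{n_routed_experts + k}.{suffix}", src)
        for k in range(n_share)
    ]

print(clone_shared_expert_names(3, 256, 1, "down_proj.weight"))
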
@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # TODO(HandH1998): Modify it when nextn is supported.
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                if num_nextn_layers > 0 and name.startswith("model.layers"):
                    name_list = name.split(".")
                    if (
                        len(name_list) >= 3
                        and int(name_list[2]) >= self.config.num_hidden_layers
                    ):
                        continue
            if not is_nextn:
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                    )
                    weight_loader(param, loaded_weight)

        self.post_load_weights()
        self.post_load_weights(is_nextn=is_nextn)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

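For DeepSeek-V3/R1 checkpoints the single MTP layer is stored after the 61 decoder layers, so when is_nextn is set the loader targets model.layers.{num_hidden_layers} and remaps those names onto the draft module tree. A rough illustration of the remapping implied by the hunks above, not part of the diff (layer index 61 is the DeepSeek-V3 depth, used only as an example):

num_hidden_layers = 61
nextn_layer_id = 0 if num_hidden_layers == 1 else num_hidden_layers  # 61
prefix = f"model.layers.{nextn_layer_id}"

def remap_nextn_name(name: str) -> str:
    # MTP-specific tensors hang off the top-level model; everything else
    # becomes part of the single draft decoder layer, model.decoder.
    nextn_spec = ("shared_head.norm", "eh_proj", "enorm", "hnorm")
    if any(w in name for w in nextn_spec):
        return name.replace(prefix, "model")
    return name.replace(prefix, "model.decoder")

assert remap_nextn_name("model.layers.61.eh_proj.weight") == "model.eh_proj.weight"
assert remap_nextn_name("model.layers.61.self_attn.kv_a_proj_with_mqa.weight") == (
    "model.decoder.self_attn.kv_a_proj_with_mqa.weight"
)
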
@@ -22,7 +22,7 @@ import random
import tempfile
from typing import List, Literal, Optional

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
    configure_ipv6,
@@ -333,6 +333,14 @@ class ServerArgs:
                "eagle speculative decoding."
            )

            model_arch = get_model_arch(self)

            # Auto set draft_model_path for DeepSeek-V3/R1
            if self.speculative_draft_model_path is None and model_arch in [
                "DeepseekV3ForCausalLM"
            ]:
                self.speculative_draft_model_path = self.model_path

            # Auto choose parameters
            if self.speculative_num_steps is None:
                assert (
@@ -343,7 +351,7 @@ class ServerArgs:
                    self.speculative_num_steps,
                    self.speculative_eagle_topk,
                    self.speculative_num_draft_tokens,
                ) = auto_choose_speculative_params(self)
                ) = auto_choose_speculative_params(model_arch)

            if self.page_size > 1 and self.speculative_eagle_topk > 1:
                self.speculative_eagle_topk = 1
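The net effect of these ServerArgs hunks is the commit title: DeepSeek-V3/R1 users no longer need to pass a separate draft model for MTP, because the MTP weights ship inside the target checkpoint. A minimal sketch of the defaulting rule, not part of the diff (the helper name is illustrative):

def resolve_draft_model_path(draft_path, model_path: str, model_arch: str):
    # Mirrors the hunk above: if no draft model is given and the target is
    # DeepSeek-V3/R1, reuse the target checkpoint as the draft checkpoint.
    if draft_path is None and model_arch in ["DeepseekV3ForCausalLM"]:
        return model_path
    return draft_path

assert (
    resolve_draft_model_path(None, "deepseek-ai/DeepSeek-V3", "DeepseekV3ForCausalLM")
    == "deepseek-ai/DeepSeek-V3"
)
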
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
        raise ValueError(self.help)


def auto_choose_speculative_params(self: ServerArgs):
def get_model_arch(args: ServerArgs):
    hf_config = get_config(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=json.loads(args.json_model_override_args),
    )
    return hf_config.architectures[0]


def auto_choose_speculative_params(arch: str):
    """
    Automatically choose the parameters for speculative decoding.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    config_path = os.path.join(self.model_path, "config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"{config_path} is not found.")

    config = json.load(open(config_path))

    arch = config.get("architectures", ["Unknown"])[0]

    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)

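With architecture detection factored into get_model_arch, the speculative-decoding defaults are now chosen from the architecture string alone. A brief usage sketch, not part of the diff; the Llama tuple comes from the branch shown above, and other architectures' defaults live in the rest of the file:

# (speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens)
steps, topk, draft_tokens = auto_choose_speculative_params("LlamaForCausalLM")
assert (steps, topk, draft_tokens) == (5, 4, 8)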