Auto set draft model path for MTP (#5793)
This commit is contained in:
@@ -47,6 +47,7 @@ class ModelConfig:
|
|||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
override_config_file: Optional[str] = None,
|
override_config_file: Optional[str] = None,
|
||||||
|
is_draft_model: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
@@ -85,6 +86,12 @@ class ModelConfig:
|
|||||||
else:
|
else:
|
||||||
enable_multimodal = True
|
enable_multimodal = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_draft_model
|
||||||
|
and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||||
|
):
|
||||||
|
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
|
||||||
|
|
||||||
# Check model type
|
# Check model type
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
self.hf_config.architectures, is_embedding
|
self.hf_config.architectures, is_embedding
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ class TpModelWorker:
|
|||||||
enable_multimodal=server_args.enable_multimodal,
|
enable_multimodal=server_args.enable_multimodal,
|
||||||
dtype=server_args.dtype,
|
dtype=server_args.dtype,
|
||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
|
is_draft_model=is_draft_worker,
|
||||||
)
|
)
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
|
|||||||
@@ -692,9 +692,14 @@ class ModelRunner:
|
|||||||
self.device, self.gpu_id, distributed=self.tp_size > 1
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
if self.use_mla_backend:
|
if self.use_mla_backend:
|
||||||
|
num_layers = (
|
||||||
|
self.model_config.num_hidden_layers
|
||||||
|
if not self.is_draft_worker
|
||||||
|
else self.model_config.hf_config.num_nextn_predict_layers
|
||||||
|
)
|
||||||
cell_size = (
|
cell_size = (
|
||||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||||
* self.model_config.num_hidden_layers
|
* num_layers
|
||||||
* torch._utils._element_size(self.kv_cache_dtype)
|
* torch._utils._element_size(self.kv_cache_dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -809,7 +814,11 @@ class ModelRunner:
|
|||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=(
|
||||||
|
self.model_config.num_hidden_layers
|
||||||
|
if not self.is_draft_worker
|
||||||
|
else self.model_config.hf_config.num_nextn_predict_layers
|
||||||
|
),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
super().load_weights(weights, is_nextn=True)
|
||||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
|
||||||
assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
|
|
||||||
assert num_nextn_layers == self.config.num_hidden_layers
|
|
||||||
else:
|
|
||||||
raise ValueError("num_nextn_predict_layers is not in the config")
|
|
||||||
|
|
||||||
stacked_params_mapping = [
|
|
||||||
# (param_name, shard_name, shard_id)
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
if self.n_share_experts_fusion > 0:
|
|
||||||
logger.info(
|
|
||||||
f"Cloning {self.n_share_experts_fusion} "
|
|
||||||
"replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
|
|
||||||
)
|
|
||||||
weights_list = list(weights)
|
|
||||||
weights_dict = dict(weights_list)
|
|
||||||
if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"down_proj.weight_scale",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"gate_proj.weight_scale",
|
|
||||||
"up_proj.weight",
|
|
||||||
"up_proj.weight_scale",
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
suffix_list = [
|
|
||||||
"down_proj.weight",
|
|
||||||
"down_proj.weight_scale_inv",
|
|
||||||
"gate_proj.weight",
|
|
||||||
"gate_proj.weight_scale_inv",
|
|
||||||
"up_proj.weight",
|
|
||||||
"up_proj.weight_scale_inv",
|
|
||||||
]
|
|
||||||
names_to_remove = []
|
|
||||||
for suffix in suffix_list:
|
|
||||||
shared_expert_weight_name = (
|
|
||||||
f"model.layers.0.mlp.shared_experts.{suffix}"
|
|
||||||
)
|
|
||||||
for num_repeat in range(self.n_share_experts_fusion):
|
|
||||||
weights_list.append(
|
|
||||||
(
|
|
||||||
f"model.layers.0."
|
|
||||||
f"mlp.experts."
|
|
||||||
f"{self.config.n_routed_experts + num_repeat}"
|
|
||||||
f".{suffix}",
|
|
||||||
weights_dict[shared_expert_weight_name],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
names_to_remove += [shared_expert_weight_name]
|
|
||||||
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
|
||||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
|
||||||
ckpt_gate_proj_name="gate_proj",
|
|
||||||
ckpt_down_proj_name="down_proj",
|
|
||||||
ckpt_up_proj_name="up_proj",
|
|
||||||
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
|
||||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
|
||||||
self.config.q_lora_rank is not None
|
|
||||||
)
|
|
||||||
cached_a_proj = {} if fuse_qkv_a_proj else None
|
|
||||||
|
|
||||||
nextn_layer_prefix = "model.layers.0"
|
|
||||||
nextn_spec_weight_names = [
|
|
||||||
"shared_head.norm",
|
|
||||||
"eh_proj",
|
|
||||||
"enorm",
|
|
||||||
"hnorm",
|
|
||||||
]
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if not name.startswith(nextn_layer_prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Use shared head and embed weights from target model
|
|
||||||
if "shared_head.head" in name or "embed_tokens" in name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
is_decoder = True
|
|
||||||
# For nextn specific weights
|
|
||||||
for weight_name in nextn_spec_weight_names:
|
|
||||||
if weight_name in name:
|
|
||||||
name = name.replace(nextn_layer_prefix, "model")
|
|
||||||
is_decoder = False
|
|
||||||
break
|
|
||||||
# For decoder layer weights
|
|
||||||
if is_decoder:
|
|
||||||
name = name.replace(nextn_layer_prefix, "model.decoder")
|
|
||||||
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
|
||||||
# Skip non-stacked layers and experts (experts handled below).
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
|
||||||
# Since we handle the experts below in expert_params_mapping,
|
|
||||||
# we need to skip here BEFORE we update the name, otherwise
|
|
||||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
|
||||||
# will then be updated below in expert_params_mapping
|
|
||||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
|
||||||
if ("mlp.experts." in name) and name not in params_dict:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
for mapping in expert_params_mapping:
|
|
||||||
param_name, weight_name, expert_id, shard_id = mapping
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(
|
|
||||||
param,
|
|
||||||
loaded_weight,
|
|
||||||
name,
|
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=expert_id,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle fused_qkv_a_proj
|
|
||||||
if fuse_qkv_a_proj and (
|
|
||||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
|
||||||
):
|
|
||||||
cached_a_proj[name] = loaded_weight
|
|
||||||
q_a_proj_name = (
|
|
||||||
name
|
|
||||||
if "q_a_proj" in name
|
|
||||||
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
|
|
||||||
)
|
|
||||||
kv_a_proj_name = (
|
|
||||||
name
|
|
||||||
if "kv_a_proj_with_mqa" in name
|
|
||||||
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
|
|
||||||
)
|
|
||||||
|
|
||||||
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
|
|
||||||
if (
|
|
||||||
q_a_proj_name in cached_a_proj
|
|
||||||
and kv_a_proj_name in cached_a_proj
|
|
||||||
):
|
|
||||||
|
|
||||||
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
|
||||||
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
|
||||||
fused_weight = torch.cat(
|
|
||||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
param_name = name.replace(
|
|
||||||
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
|
||||||
)
|
|
||||||
param = params_dict[param_name]
|
|
||||||
|
|
||||||
weight_loader = getattr(
|
|
||||||
param, "weight_loader", default_weight_loader
|
|
||||||
)
|
|
||||||
weight_loader(param, fused_weight)
|
|
||||||
cached_a_proj.pop(q_a_proj_name)
|
|
||||||
cached_a_proj.pop(kv_a_proj_name)
|
|
||||||
else:
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(
|
|
||||||
param, "weight_loader", default_weight_loader
|
|
||||||
)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
|
|
||||||
self_attn = self.model.decoder.self_attn
|
|
||||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
|
||||||
# AWQ compatible
|
|
||||||
if _is_cuda:
|
|
||||||
w = awq_dequantize(
|
|
||||||
self_attn.kv_b_proj.qweight,
|
|
||||||
self_attn.kv_b_proj.scales,
|
|
||||||
self_attn.kv_b_proj.qzeros,
|
|
||||||
).T
|
|
||||||
else:
|
|
||||||
w = awq_dequantize(
|
|
||||||
self_attn.kv_b_proj.qweight,
|
|
||||||
self_attn.kv_b_proj.scales,
|
|
||||||
self_attn.kv_b_proj.qzeros,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
).T
|
|
||||||
else:
|
|
||||||
w = self_attn.kv_b_proj.weight
|
|
||||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
|
||||||
# This may affect the accuracy of fp8 model.
|
|
||||||
if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
|
|
||||||
torch.float8_e4m3fn,
|
|
||||||
torch.float8_e4m3fnuz,
|
|
||||||
):
|
|
||||||
weight_block_size = self.quant_config.weight_block_size
|
|
||||||
if weight_block_size is not None:
|
|
||||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
|
||||||
if _is_hip:
|
|
||||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=w,
|
|
||||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
|
||||||
input_scale=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
weight = w
|
|
||||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
|
||||||
|
|
||||||
w, scale = block_quant_to_tensor_quant(
|
|
||||||
weight, weight_scale, weight_block_size
|
|
||||||
)
|
|
||||||
self_attn.w_scale = scale
|
|
||||||
if w.dtype == torch.int8:
|
|
||||||
if hasattr(self.quant_config, "weight_block_size"):
|
|
||||||
# block-wise int8 need it
|
|
||||||
weight_block_size = self.quant_config.weight_block_size
|
|
||||||
if weight_block_size is not None:
|
|
||||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
|
||||||
weight = w
|
|
||||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
|
||||||
w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
|
|
||||||
torch.bfloat16
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# channel-wise int8 need it
|
|
||||||
assert hasattr(self_attn.kv_b_proj, "weight_scale")
|
|
||||||
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
|
|
||||||
torch.bfloat16
|
|
||||||
)
|
|
||||||
w_kc, w_vc = w.unflatten(
|
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
|
||||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
|
||||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
|
||||||
if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
|
|
||||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
|
||||||
if _is_hip:
|
|
||||||
self_attn.w_scale *= 2.0
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [DeepseekV3ForCausalLMNextN]
|
EntryClass = [DeepseekV3ForCausalLMNextN]
|
||||||
|
|||||||
@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def post_load_weights(self):
|
def post_load_weights(self, is_nextn=False):
|
||||||
|
|
||||||
# Perform post-processing after loading weights
|
# Perform post-processing after loading weights
|
||||||
for layer_id in range(self.config.num_hidden_layers):
|
layer_ids = (
|
||||||
self_attn = self.model.layers[layer_id].self_attn
|
range(self.config.num_hidden_layers)
|
||||||
|
if not is_nextn
|
||||||
|
else [self.config.num_hidden_layers]
|
||||||
|
)
|
||||||
|
for layer_id in layer_ids:
|
||||||
|
self_attn = (
|
||||||
|
self.model.layers[layer_id].self_attn
|
||||||
|
if not is_nextn
|
||||||
|
else self.model.decoder.self_attn
|
||||||
|
)
|
||||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||||
# AWQ compatible
|
# AWQ compatible
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self_attn.w_vc = w_vc.contiguous()
|
self_attn.w_vc = w_vc.contiguous()
|
||||||
self_attn.use_deep_gemm_bmm = True
|
self_attn.use_deep_gemm_bmm = True
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||||
|
if is_nextn:
|
||||||
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
|
assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
|
||||||
|
# compatible with old design
|
||||||
|
nextn_layer_id = (
|
||||||
|
0
|
||||||
|
if self.config.num_hidden_layers == 1
|
||||||
|
else self.config.num_hidden_layers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("num_nextn_predict_layers is not in the config")
|
||||||
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"up_proj.weight_scale_inv",
|
"up_proj.weight_scale_inv",
|
||||||
]
|
]
|
||||||
names_to_remove = []
|
names_to_remove = []
|
||||||
for moe_layer in tqdm(
|
|
||||||
|
moe_layers = (
|
||||||
range(
|
range(
|
||||||
self.config.first_k_dense_replace,
|
self.config.first_k_dense_replace,
|
||||||
self.config.num_hidden_layers,
|
self.config.num_hidden_layers,
|
||||||
self.config.moe_layer_freq,
|
self.config.moe_layer_freq,
|
||||||
),
|
)
|
||||||
|
if not is_nextn
|
||||||
|
else [nextn_layer_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
for moe_layer in tqdm(
|
||||||
|
moe_layers,
|
||||||
desc=f"Cloning {self.n_share_experts_fusion} "
|
desc=f"Cloning {self.n_share_experts_fusion} "
|
||||||
"replicas of the shared expert into MoE",
|
"replicas of the shared expert into MoE",
|
||||||
):
|
):
|
||||||
@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
cached_a_proj = {} if fuse_qkv_a_proj else None
|
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||||
|
|
||||||
|
if is_nextn:
|
||||||
|
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
|
||||||
|
nextn_spec_weight_names = [
|
||||||
|
"shared_head.norm",
|
||||||
|
"eh_proj",
|
||||||
|
"enorm",
|
||||||
|
"hnorm",
|
||||||
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
# TODO(HandH1998): Modify it when nextn is supported.
|
if not is_nextn:
|
||||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
if num_nextn_layers > 0 and name.startswith("model.layers"):
|
if num_nextn_layers > 0 and name.startswith("model.layers"):
|
||||||
name_list = name.split(".")
|
name_list = name.split(".")
|
||||||
if (
|
if (
|
||||||
len(name_list) >= 3
|
len(name_list) >= 3
|
||||||
and int(name_list[2]) >= self.config.num_hidden_layers
|
and int(name_list[2]) >= self.config.num_hidden_layers
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
if not name.startswith(nextn_layer_prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Use shared head and embed weights from target model
|
||||||
|
if "shared_head.head" in name or "embed_tokens" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_decoder = True
|
||||||
|
# For nextn specific weights
|
||||||
|
for weight_name in nextn_spec_weight_names:
|
||||||
|
if weight_name in name:
|
||||||
|
name = name.replace(nextn_layer_prefix, "model")
|
||||||
|
is_decoder = False
|
||||||
|
break
|
||||||
|
# For decoder layer weights
|
||||||
|
if is_decoder:
|
||||||
|
name = name.replace(nextn_layer_prefix, "model.decoder")
|
||||||
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
self.post_load_weights()
|
self.post_load_weights(is_nextn=is_nextn)
|
||||||
|
|
||||||
def get_embed_and_head(self):
|
def get_embed_and_head(self):
|
||||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
configure_ipv6,
|
configure_ipv6,
|
||||||
@@ -333,6 +333,14 @@ class ServerArgs:
|
|||||||
"eagle speculative decoding."
|
"eagle speculative decoding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_arch = get_model_arch(self)
|
||||||
|
|
||||||
|
# Auto set draft_model_path DeepSeek-V3/R1
|
||||||
|
if self.speculative_draft_model_path is None and model_arch in [
|
||||||
|
"DeepseekV3ForCausalLM"
|
||||||
|
]:
|
||||||
|
self.speculative_draft_model_path = self.model_path
|
||||||
|
|
||||||
# Auto choose parameters
|
# Auto choose parameters
|
||||||
if self.speculative_num_steps is None:
|
if self.speculative_num_steps is None:
|
||||||
assert (
|
assert (
|
||||||
@@ -343,7 +351,7 @@ class ServerArgs:
|
|||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
self.speculative_eagle_topk,
|
self.speculative_eagle_topk,
|
||||||
self.speculative_num_draft_tokens,
|
self.speculative_num_draft_tokens,
|
||||||
) = auto_choose_speculative_params(self)
|
) = auto_choose_speculative_params(model_arch)
|
||||||
|
|
||||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||||
self.speculative_eagle_topk = 1
|
self.speculative_eagle_topk = 1
|
||||||
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
|
|||||||
raise ValueError(self.help)
|
raise ValueError(self.help)
|
||||||
|
|
||||||
|
|
||||||
def auto_choose_speculative_params(self: ServerArgs):
|
def get_model_arch(args: ServerArgs):
|
||||||
|
hf_config = get_config(
|
||||||
|
args.model_path,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
revision=args.revision,
|
||||||
|
model_override_args=json.loads(args.json_model_override_args),
|
||||||
|
)
|
||||||
|
return hf_config.architectures[0]
|
||||||
|
|
||||||
|
|
||||||
|
def auto_choose_speculative_params(arch: str):
|
||||||
"""
|
"""
|
||||||
Automatically choose the parameters for speculative decoding.
|
Automatically choose the parameters for speculative decoding.
|
||||||
|
|
||||||
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
|
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
|
||||||
"""
|
"""
|
||||||
config_path = os.path.join(self.model_path, "config.json")
|
|
||||||
if not os.path.exists(config_path):
|
|
||||||
raise ValueError(f"{config_path} is not found.")
|
|
||||||
|
|
||||||
config = json.load(open(config_path))
|
|
||||||
|
|
||||||
arch = config.get("architectures", ["Unknown"])[0]
|
|
||||||
|
|
||||||
if arch in ["LlamaForCausalLM"]:
|
if arch in ["LlamaForCausalLM"]:
|
||||||
# The default value for llama
|
# The default value for llama
|
||||||
return (5, 4, 8)
|
return (5, 4, 8)
|
||||||
|
|||||||
Reference in New Issue
Block a user