concurrently load weights of DeepseekV2ForCausalLM (#7943)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
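Wrap the weight-loading loop of `DeepseekV2ForCausalLM.load_weights` in a `concurrent.futures.ThreadPoolExecutor`: each `weight_loader` call is submitted to the pool as a task, and a final `concurrent.futures.as_completed` pass waits on every future, re-raising any worker exception before `post_load_weights` runs.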
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -2436,154 +2437,174 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
-        params_dict = dict(self.named_parameters())
-        weight_names = []
-        for name, loaded_weight in weights:
-            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
-                name = name.replace(
-                    "mlp.shared_experts",
-                    f"mlp.experts.{self.config.n_routed_experts}",
-                )
-
-            weight_names.append(name)
-
-            if not is_nextn:
-                if hasattr(self.config, "num_nextn_predict_layers"):
-                    num_nextn_layers = self.config.num_nextn_predict_layers
-                    if num_nextn_layers > 0 and name.startswith("model.layers"):
-                        name_list = name.split(".")
-                        if (
-                            len(name_list) >= 3
-                            and int(name_list[2]) >= self.config.num_hidden_layers
-                        ):
-                            continue
-            else:
-                if not name.startswith(nextn_layer_prefix):
-                    continue
-
-                # Use shared head and embed weights from target model
-                if "shared_head.head" in name or "embed_tokens" in name:
-                    continue
-
-                is_decoder = True
-                # For nextn specific weights
-                for weight_name in nextn_spec_weight_names:
-                    if weight_name in name:
-                        name = name.replace(nextn_layer_prefix, "model")
-                        is_decoder = False
-                        break
-                # For decoder layer weights
-                if is_decoder:
-                    name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            cat_dim = 0
-                            if self.quant_config is not None and (
-                                self.quant_config.get_name() == "awq"
-                                or self.quant_config.get_name() == "moe_wna16"
-                            ):
-                                cat_dim = 1
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            for scale in ["k_scale", "v_scale"]:
-                                if scale in name:
-                                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
-                                    break
-                            if name not in params_dict:
-                                # modelopt ckpt contains not needed weights for MTP module:
-                                # model.decoder.self_attn.attn_mqa.v_scale and
-                                # model.decoder.self_attn.attn_mqa.k_scale
-                                logger.warning(f"{name} not found in params_dict.")
-                                continue
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            params_dict = dict(self.named_parameters())
+            weight_names = []
+            for name, loaded_weight in weights:
+                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                    name = name.replace(
+                        "mlp.shared_experts",
+                        f"mlp.experts.{self.config.n_routed_experts}",
+                    )
+
+                weight_names.append(name)
+
+                if not is_nextn:
+                    if hasattr(self.config, "num_nextn_predict_layers"):
+                        num_nextn_layers = self.config.num_nextn_predict_layers
+                        if num_nextn_layers > 0 and name.startswith("model.layers"):
+                            name_list = name.split(".")
+                            if (
+                                len(name_list) >= 3
+                                and int(name_list[2]) >= self.config.num_hidden_layers
+                            ):
+                                continue
+                else:
+                    if not name.startswith(nextn_layer_prefix):
+                        continue
+
+                    # Use shared head and embed weights from target model
+                    if "shared_head.head" in name or "embed_tokens" in name:
+                        continue
+
+                    is_decoder = True
+                    # For nextn specific weights
+                    for weight_name in nextn_spec_weight_names:
+                        if weight_name in name:
+                            name = name.replace(nextn_layer_prefix, "model")
+                            is_decoder = False
+                            break
+                    # For decoder layer weights
+                    if is_decoder:
+                        name = name.replace(nextn_layer_prefix, "model.decoder")
+
+                if "rotary_emb.inv_freq" in name:
+                    continue
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    # Skip non-stacked layers and experts (experts handled below).
+                    if weight_name not in name:
+                        continue
+                    # We have mlp.experts[0].gate_proj in the checkpoint.
+                    # Since we handle the experts below in expert_params_mapping,
+                    # we need to skip here BEFORE we update the name, otherwise
+                    # name will be updated to mlp.experts[0].gate_up_proj, which
+                    # will then be updated below in expert_params_mapping
+                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                    if ("mlp.experts." in name) and name not in params_dict:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
+                    )
+                    break
+                else:
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                        )
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if fuse_qkv_a_proj and (
+                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                        ):
+                            cached_a_proj[name] = loaded_weight
+                            q_a_proj_name = (
+                                name
+                                if "q_a_proj" in name
+                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                            )
+                            kv_a_proj_name = (
+                                name
+                                if "kv_a_proj_with_mqa" in name
+                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                            )
+
+                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                            if (
+                                q_a_proj_name in cached_a_proj
+                                and kv_a_proj_name in cached_a_proj
+                            ):
+                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                                cat_dim = 0
+                                if self.quant_config is not None and (
+                                    self.quant_config.get_name() == "awq"
+                                    or self.quant_config.get_name() == "moe_wna16"
+                                ):
+                                    cat_dim = 1
+                                fused_weight = torch.cat(
+                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                                )
+                                param_name = (
+                                    name.replace(
+                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                                    )
+                                    if "q_a_proj" in name
+                                    else name.replace(
+                                        "kv_a_proj_with_mqa",
+                                        "fused_qkv_a_proj_with_mqa",
+                                    )
+                                )
+                                param = params_dict[param_name]
+
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_a_proj.pop(q_a_proj_name)
+                                cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                for scale in ["k_scale", "v_scale"]:
+                                    if scale in name:
+                                        name = name.replace(
+                                            f"{scale[0]}_proj", "attn_mqa"
+                                        )
+                                        break
+                                if name not in params_dict:
+                                    # modelopt ckpt contains not needed weights for MTP module:
+                                    # model.decoder.self_attn.attn_mqa.v_scale and
+                                    # model.decoder.self_attn.attn_mqa.k_scale
+                                    logger.warning(f"{name} not found in params_dict.")
+                                    continue
+                            param = params_dict[name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
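
What changed, in miniature: every direct `weight_loader(...)` call now goes through `executor.submit(...)`, and the collected futures are drained with `concurrent.futures.as_completed`, whose `result()` re-raises any exception from a worker thread. A minimal self-contained sketch of the same submit-then-join pattern (the `fake_loader` helper, tensor names, and shapes are illustrative, not from the commit):

import concurrent.futures

import torch


def fake_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Stand-in for an SGLang/vLLM-style weight_loader: copy checkpoint data
    # into the preallocated parameter. Tensor.copy_ releases the GIL for the
    # underlying memcpy, which is what makes a thread pool worthwhile here.
    param.copy_(loaded_weight)


params = {f"layers.{i}.weight": torch.empty(1024, 1024) for i in range(8)}
checkpoint = {name: torch.randn(1024, 1024) for name in params}

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(fake_loader, params[name], weight)
        for name, weight in checkpoint.items()
    ]
    # as_completed yields each future as it finishes; result() re-raises any
    # worker-thread exception instead of letting a failed load pass silently.
    for future in concurrent.futures.as_completed(futures):
        future.result()

Each task writes to its own preallocated parameter, so the threads share no mutable state; the speedup comes from overlapping GIL-releasing copies and host-to-device transfers, not from running Python bytecode in parallel.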
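The rewritten branch structure also preserves the fused `q_a_proj`/`kv_a_proj_with_mqa` path: each half is cached until its partner arrives, the two weights are concatenated along `cat_dim` (1 for awq/moe_wna16 checkpoints, 0 otherwise), and the result is loaded into `fused_qkv_a_proj_with_mqa`. A sketch of just that cache-and-fuse step; `maybe_fuse` is a hypothetical helper and the shapes are made up, not taken from a real config:

import torch

cached: dict[str, torch.Tensor] = {}


def maybe_fuse(name: str, weight: torch.Tensor, cat_dim: int = 0):
    # Cache q_a_proj / kv_a_proj_with_mqa halves until both have been seen,
    # then pop them and return the concatenated weight; None means "still
    # waiting for the partner tensor".
    cached[name] = weight
    q_name = (
        name
        if "q_a_proj" in name
        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
    )
    kv_name = (
        name
        if "kv_a_proj_with_mqa" in name
        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
    )
    if q_name in cached and kv_name in cached:
        return torch.cat([cached.pop(q_name), cached.pop(kv_name)], dim=cat_dim)
    return None


# Illustrative shapes: a 1536-row q_a_proj and a 576-row kv_a_proj_with_mqa.
q_w = torch.randn(1536, 7168)
kv_w = torch.randn(576, 7168)
assert maybe_fuse("layers.0.self_attn.q_a_proj.weight", q_w) is None
fused = maybe_fuse("layers.0.self_attn.kv_a_proj_with_mqa.weight", kv_w)
assert fused is not None and fused.shape == (2112, 7168)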
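The `num_nextn_predict_layers` guard in both versions filters MTP (multi-token prediction) weights out of a base-model load: a `model.layers.<i>` name whose index reaches `num_hidden_layers` belongs to the speculative nextn stack, not the target model. The same check as a standalone predicate; `is_nextn_layer_weight` is a hypothetical name and the sample layer count assumes a DeepSeek-V3-style config:

def is_nextn_layer_weight(name: str, num_hidden_layers: int) -> bool:
    # Mirrors the diff's inline check: layer indices past the last hidden
    # layer are nextn/MTP layers and should be skipped by the base model.
    if not name.startswith("model.layers"):
        return False
    parts = name.split(".")
    return len(parts) >= 3 and int(parts[2]) >= num_hidden_layers


# With 61 decoder layers, "model.layers.61" is the single MTP layer.
assert not is_nextn_layer_weight("model.layers.60.mlp.gate.weight", 61)
assert is_nextn_layer_weight("model.layers.61.shared_head.head.weight", 61)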