concurrently load weights of DeepseekV2ForCausalLM (#7943)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
|||||||
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
|
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
|
||||||
"""Inference-only DeepseekV2 model."""
|
"""Inference-only DeepseekV2 model."""
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
@@ -2436,6 +2437,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
assert self.num_fused_shared_experts == 1
|
assert self.num_fused_shared_experts == 1
|
||||||
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
|
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
futures = []
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
weight_names = []
|
weight_names = []
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
@@ -2496,7 +2499,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
futures.append(
|
||||||
|
executor.submit(weight_loader, param, loaded_weight, shard_id)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for mapping in expert_params_mapping:
|
for mapping in expert_params_mapping:
|
||||||
@@ -2506,13 +2511,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
futures.append(
|
||||||
|
executor.submit(
|
||||||
|
weight_loader,
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
name,
|
name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
@@ -2550,10 +2558,13 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
|
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
|
||||||
)
|
)
|
||||||
param_name = (
|
param_name = (
|
||||||
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
name.replace(
|
||||||
|
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
||||||
|
)
|
||||||
if "q_a_proj" in name
|
if "q_a_proj" in name
|
||||||
else name.replace(
|
else name.replace(
|
||||||
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
|
"kv_a_proj_with_mqa",
|
||||||
|
"fused_qkv_a_proj_with_mqa",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
param = params_dict[param_name]
|
param = params_dict[param_name]
|
||||||
@@ -2561,7 +2572,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
)
|
)
|
||||||
weight_loader(param, fused_weight)
|
futures.append(
|
||||||
|
executor.submit(weight_loader, param, fused_weight)
|
||||||
|
)
|
||||||
cached_a_proj.pop(q_a_proj_name)
|
cached_a_proj.pop(q_a_proj_name)
|
||||||
cached_a_proj.pop(kv_a_proj_name)
|
cached_a_proj.pop(kv_a_proj_name)
|
||||||
else:
|
else:
|
||||||
@@ -2571,7 +2584,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
# modelopt attn kv scale is named differently
|
# modelopt attn kv scale is named differently
|
||||||
for scale in ["k_scale", "v_scale"]:
|
for scale in ["k_scale", "v_scale"]:
|
||||||
if scale in name:
|
if scale in name:
|
||||||
name = name.replace(f"{scale[0]}_proj", "attn_mqa")
|
name = name.replace(
|
||||||
|
f"{scale[0]}_proj", "attn_mqa"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if name not in params_dict:
|
if name not in params_dict:
|
||||||
# modelopt ckpt contains not needed weights for MTP module:
|
# modelopt ckpt contains not needed weights for MTP module:
|
||||||
@@ -2583,7 +2598,13 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
futures.append(
|
||||||
|
executor.submit(weight_loader, param, loaded_weight)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete and raise any exceptions.
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
future.result()
|
||||||
|
|
||||||
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user