diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index aa47daa30..2b87d91d4 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -2436,154 +2437,174 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
-        params_dict = dict(self.named_parameters())
-        weight_names = []
-        for name, loaded_weight in weights:
-            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
-                name = name.replace(
-                    "mlp.shared_experts",
-                    f"mlp.experts.{self.config.n_routed_experts}",
-                )
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            params_dict = dict(self.named_parameters())
+            weight_names = []
+            for name, loaded_weight in weights:
+                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                    name = name.replace(
+                        "mlp.shared_experts",
+                        f"mlp.experts.{self.config.n_routed_experts}",
+                    )
 
-            weight_names.append(name)
+                weight_names.append(name)
 
-            if not is_nextn:
-                if hasattr(self.config, "num_nextn_predict_layers"):
-                    num_nextn_layers = self.config.num_nextn_predict_layers
-                    if num_nextn_layers > 0 and name.startswith("model.layers"):
-                        name_list = name.split(".")
-                        if (
-                            len(name_list) >= 3
-                            and int(name_list[2]) >= self.config.num_hidden_layers
-                        ):
-                            continue
-            else:
-                if not name.startswith(nextn_layer_prefix):
-                    continue
+                if not is_nextn:
+                    if hasattr(self.config, "num_nextn_predict_layers"):
+                        num_nextn_layers = self.config.num_nextn_predict_layers
+                        if num_nextn_layers > 0 and name.startswith("model.layers"):
+                            name_list = name.split(".")
+                            if (
+                                len(name_list) >= 3
+                                and int(name_list[2]) >= self.config.num_hidden_layers
+                            ):
+                                continue
+                else:
+                    if not name.startswith(nextn_layer_prefix):
+                        continue
 
-                # Use shared head and embed weights from target model
-                if "shared_head.head" in name or "embed_tokens" in name:
-                    continue
+                    # Use shared head and embed weights from target model
+                    if "shared_head.head" in name or "embed_tokens" in name:
+                        continue
 
-                is_decoder = True
-                # For nextn specific weights
-                for weight_name in nextn_spec_weight_names:
-                    if weight_name in name:
-                        name = name.replace(nextn_layer_prefix, "model")
-                        is_decoder = False
-                        break
-                # For decoder layer weights
-                if is_decoder:
-                    name = name.replace(nextn_layer_prefix, "model.decoder")
+                    is_decoder = True
+                    # For nextn specific weights
+                    for weight_name in nextn_spec_weight_names:
+                        if weight_name in name:
+                            name = name.replace(nextn_layer_prefix, "model")
+                            is_decoder = False
+                            break
+                    # For decoder layer weights
+                    if is_decoder:
+                        name = name.replace(nextn_layer_prefix, "model.decoder")
 
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
+                if "rotary_emb.inv_freq" in name:
+                    continue
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    # Skip non-stacked layers and experts (experts handled below).
+                    if weight_name not in name:
+                        continue
+                    # We have mlp.experts[0].gate_proj in the checkpoint.
+                    # Since we handle the experts below in expert_params_mapping,
+                    # we need to skip here BEFORE we update the name, otherwise
+                    # name will be updated to mlp.experts[0].gate_up_proj, which
+                    # will then be updated below in expert_params_mapping
+                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                    if ("mlp.experts." in name) and name not in params_dict:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
+                    )
+                    break
+                else:
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                        )
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if fuse_qkv_a_proj and (
+                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                        ):
+                            cached_a_proj[name] = loaded_weight
+                            q_a_proj_name = (
+                                name
+                                if "q_a_proj" in name
+                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                            )
+                            kv_a_proj_name = (
+                                name
+                                if "kv_a_proj_with_mqa" in name
+                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                            )
 
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            cat_dim = 0
-                            if self.quant_config is not None and (
-                                self.quant_config.get_name() == "awq"
-                                or self.quant_config.get_name() == "moe_wna16"
-                            ):
-                                cat_dim = 1
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
+                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                            if (
+                                q_a_proj_name in cached_a_proj
+                                and kv_a_proj_name in cached_a_proj
+                            ):
+                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                                cat_dim = 0
+                                if self.quant_config is not None and (
+                                    self.quant_config.get_name() == "awq"
+                                    or self.quant_config.get_name() == "moe_wna16"
+                                ):
+                                    cat_dim = 1
+                                fused_weight = torch.cat(
+                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                                )
+                                param_name = (
+                                    name.replace(
+                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                                    )
+                                    if "q_a_proj" in name
+                                    else name.replace(
+                                        "kv_a_proj_with_mqa",
+                                        "fused_qkv_a_proj_with_mqa",
+                                    )
+                                )
+                                param = params_dict[param_name]
 
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            for scale in ["k_scale", "v_scale"]:
-                                if scale in name:
-                                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
-                                    break
-                            if name not in params_dict:
-                                # modelopt ckpt contains not needed weights for MTP module:
-                                # model.decoder.self_attn.attn_mqa.v_scale and
-                                # model.decoder.self_attn.attn_mqa.k_scale
-                                logger.warning(f"{name} not found in params_dict.")
-                                continue
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_a_proj.pop(q_a_proj_name)
+                                cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                for scale in ["k_scale", "v_scale"]:
+                                    if scale in name:
+                                        name = name.replace(
+                                            f"{scale[0]}_proj", "attn_mqa"
+                                        )
+                                        break
+                                if name not in params_dict:
+                                    # modelopt ckpt contains not needed weights for MTP module:
+                                    # model.decoder.self_attn.attn_mqa.v_scale and
+                                    # model.decoder.self_attn.attn_mqa.k_scale
+                                    logger.warning(f"{name} not found in params_dict.")
+                                    continue
+                            param = params_dict[name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
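For readers skimming the patch: the change is the standard concurrent.futures idiom of submitting each independent weight-copy call to a thread pool and then draining the futures so that any worker exception propagates to the caller. Below is a minimal, self-contained sketch of that pattern under simplified assumptions; the load_weights_parallel helper and the toy params/weights data are illustrative only and are not part of this diff or of the SGLang API.

import concurrent.futures

import torch


def load_weights_parallel(params, weights):
    # params: dict of name -> destination tensor; weights: iterable of (name, tensor).
    # Mirrors the submit / as_completed pattern used in the patch, with a plain
    # Tensor.copy_ standing in for the model's weight_loader callbacks.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(params[name].copy_, tensor) for name, tensor in weights
        ]
        # Draining the futures blocks until every copy has finished and
        # re-raises any exception thrown inside a worker thread.
        for future in concurrent.futures.as_completed(futures):
            future.result()


if __name__ == "__main__":
    params = {"w1": torch.zeros(4), "w2": torch.zeros(4)}
    weights = [("w1", torch.ones(4)), ("w2", torch.full((4,), 2.0))]
    load_weights_parallel(params, weights)
    print(params["w1"], params["w2"])

Threads can help here because the underlying tensor copies largely release the GIL and are memory- or I/O-bound, so per-weight work overlaps even within a single Python process.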