From 7ed77d6b9e5cd41ca8cc217d1afe062f147302fd Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Sat, 5 Apr 2025 06:22:37 +0800
Subject: [PATCH] fix dummy-load deepseekv2 (#4535)

---
 python/sglang/srt/model_loader/loader.py |   8 ++
 python/sglang/srt/models/deepseek_v2.py  | 152 ++++++++++++-----
 2 files changed, 87 insertions(+), 73 deletions(-)

diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 07e582cc5..7580967a6 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8b859bfdc..91b3495f8 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1380,6 +1380,84 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights.
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of the fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 needs it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 needs it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -1504,79 +1582,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
- if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 - ) - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = self_attn.kv_b_proj.weight_scale - if _is_hip: - self_attn.w_scale *= 2.0 + self.post_load_weights() def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight
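
Note on the hook contract: the loader-side change treats `post_load_weights`
as an optional hook that a model may expose, and calls it only when present.
For DeepseekV2ForCausalLM, that hook is where the MLA members (`w_kc`, `w_vc`,
`w_scale`) are derived from `kv_b_proj`; under dummy loading, `load_weights`
never runs, so those members previously stayed unset. A minimal standalone
sketch of the two-stage contract follows (plain PyTorch; `ToyModel` and
`DummyLoader` are illustrative names, not sglang APIs):

    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.kv_b_proj = nn.Linear(8, 16, bias=False)
            self.w_kc = None  # derived member, only valid after stage 2

        def post_load_weights(self):
            # Stage 2: derive member tensors from the (possibly dummy)
            # weights, analogous to splitting kv_b_proj into w_kc / w_vc.
            self.w_kc = self.kv_b_proj.weight.detach().clone()

    class DummyLoader:
        def load_model(self, model: nn.Module) -> nn.Module:
            # Stage 1: assign random values instead of reading a checkpoint.
            for param in model.parameters():
                nn.init.uniform_(param, -1e-3, 1e-3)
            # Stage 2 must still run; otherwise derived members stay None,
            # which is the dummy-load failure this patch fixes.
            if hasattr(model, "post_load_weights"):
                model.post_load_weights()
            return model.eval()

    model = DummyLoader().load_model(ToyModel())
    assert model.w_kc is not None  # stage 2 ran without a checkpoint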