fix dummy-load deepseekv2 (#4535)
@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()
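The hunk above relies on a duck-typed contract between the loader and the model: a model may optionally define `post_load_weights()`, and after `initialize_dummy_weights(model)` fills the parameters with random values, the dummy loader invokes that hook so stage 2 still runs. A minimal, self-contained sketch of the contract (the `TinyModel` and `dummy_load` names below are illustrative only, not part of this commit):

import torch
from torch import nn


class TinyModel(nn.Module):
    """Illustrative stand-in for a model that derives extra members from its weights."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8, bias=False)
        self.proj_t = None  # filled in by post-processing (stage 2)

    def post_load_weights(self) -> None:
        # Stage 2: derive member variables from whatever weights are present,
        # whether they came from a checkpoint or from dummy initialization.
        self.proj_t = self.proj.weight.detach().T.contiguous()


def dummy_load(model: nn.Module) -> nn.Module:
    # Stage 1 stand-in: random values instead of reading a checkpoint,
    # mirroring initialize_dummy_weights(model) in the hunk above.
    with torch.no_grad():
        for p in model.parameters():
            p.uniform_(-1e-3, 1e-3)
    # Stage 2 is still required for dummy_init, hence the hook call.
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()
    return model.eval()


model = dummy_load(TinyModel())
assert model.proj_t is not None

Only the hook name matters: DeepseekV2ForCausalLM gets its MLA post-processing applied even when no checkpoint is read, which is what the next two hunks provide.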
@@ -1380,6 +1380,84 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
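The last part of `post_load_weights` reshapes the (dequantized) `kv_b_proj` weight into the per-head `w_kc`/`w_vc` factors consumed by the MLA attention path. A standalone sketch of just that shape manipulation, using made-up sizes (`num_heads`, `qk_nope_head_dim`, `v_head_dim`, `kv_lora_rank` below are placeholders, not DeepSeek-V2's actual configuration):

import torch

# Placeholder dimensions for illustration only.
num_heads = 4
qk_nope_head_dim = 16
v_head_dim = 16
kv_lora_rank = 32

# kv_b_proj maps the compressed KV latent (kv_lora_rank) to per-head K-nope and V.
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Split the output dimension into (head, qk_nope + v), then into the two factors.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)

# Same layout transforms as the hunk above.
w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)  # (heads, qk_nope, rank)
w_vc = w_vc.contiguous().transpose(1, 2)                  # (heads, rank, v_head_dim)

assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
assert w_vc.shape == (num_heads, kv_lora_rank, v_head_dim)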
@@ -1504,79 +1582,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 need it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
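Both hunks carry the NOTE(HandH1998) comment: `bmm_fp8` accepts only a per-tensor scale, so a block-quantized fp8 `kv_b_proj` has to be requantized before the w_kc/w_vc split, at some cost in accuracy. A rough, torch-only sketch of that idea, assuming `block_quant_to_tensor_quant` dequantizes with the per-block scales and requantizes with a single scale (this is an assumption about its behavior, not the real implementation):

import torch


def block_to_tensor_requant_sketch(w_q, w_s, block_size):
    # Rough idea only (assumed behavior, not the real block_quant_to_tensor_quant):
    # 1) dequantize the block-scaled weight, 2) requantize with one tensor-wide scale.
    bn, bk = block_size
    # Expand per-block scales to the full weight shape, then dequantize.
    scale_full = w_s.repeat_interleave(bn, dim=0).repeat_interleave(bk, dim=1)
    scale_full = scale_full[: w_q.shape[0], : w_q.shape[1]]
    w = w_q.to(torch.float32) * scale_full

    # Single per-tensor scale sized to the fp8 e4m3 range, as `bmm_fp8` expects.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w.abs().max() / fp8_max
    w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale


# Usage with made-up shapes: a 256x128 weight quantized in 128x128 blocks.
w_q = torch.randn(256, 128).to(torch.float8_e4m3fn)
w_s = torch.rand(2, 1) + 0.5
w_fp8, scale = block_to_tensor_requant_sketch(w_q, w_s, (128, 128))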