Fix AWQ with enable MLA (#2364)

Author:    Ke Bao
Date:      2024-12-06 00:46:48 +08:00
Committer: GitHub
Parent:    2b0fc5941d
Commit:    4a63c181f1


@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
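
For context: the new import pulls in vLLM's AWQ dequantize kernel, which the second hunk below uses to recover a dense fp16 weight from the packed 4-bit tensors. A minimal sketch of the call, with hypothetical shapes following the standard AWQ layout (eight 4-bit values packed per int32, group size 128); the three trailing zeros match the diff's usage, and a CUDA build of vLLM is assumed:

    import torch
    from vllm import _custom_ops as ops

    # Hypothetical dims; AWQ packs eight 4-bit values into each int32.
    in_features, out_features, group_size = 512, 1024, 128
    qweight = torch.zeros(in_features, out_features // 8, dtype=torch.int32, device="cuda")
    qzeros = torch.zeros(in_features // group_size, out_features // 8, dtype=torch.int32, device="cuda")
    scales = torch.ones(in_features // group_size, out_features, dtype=torch.float16, device="cuda")

    # Returns [in_features, out_features]; the trailing .T restores the usual
    # nn.Linear [out_features, in_features] weight layout.
    w = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0).T
    print(w.shape)  # torch.Size([1024, 512])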
@@ -894,7 +895,19 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
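
Once w is dense, the unchanged tail of the hunk splits it per attention head into the absorbed MLA matrices w_kc and w_vc. A self-contained sketch of that reshape, using DeepSeek-V2-like dimensions (128 heads, qk_nope_head_dim = v_head_dim = 128, kv_lora_rank = 512; treat these as assumptions for illustration):

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

    # kv_b_proj maps kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim),
    # so its weight is [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank].
    w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    w_kc, w_vc = w.unflatten(
        0, (-1, qk_nope_head_dim + v_head_dim)
    ).split([qk_nope_head_dim, v_head_dim], dim=1)

    print(w_kc.shape)  # torch.Size([128, 128, 512])
    print(w_vc.shape)  # torch.Size([128, 128, 512])

    # As in the diff: round-trip through a contiguous buffer so w_kc keeps its
    # logical shape but gets a per-head transposed memory layout.
    w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)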