Fix AWQ with enable MLA (#2364)
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -894,7 +895,19 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
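For context, a minimal standalone sketch of what the changed block computes, using toy DeepSeek-V2-like shapes and plain torch. The AWQ branch is only described in comments here, since ops.awq_dequantize needs a real AWQ checkpoint and a CUDA device; the concrete dimension values below are illustrative assumptions, not taken from the patch.

import torch

# Toy dimensions: kv_b_proj maps the compressed KV latent (kv_lora_rank wide)
# to (qk_nope_head_dim + v_head_dim) rows per attention head.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# For an unquantized checkpoint this is simply kv_b_proj.weight.
# For an AWQ checkpoint the layer stores packed qweight/scales/qzeros instead
# of a dense weight, so the patch first rebuilds w via ops.awq_dequantize(...).T
# and then reuses the same split below.
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Split the joint key/value up-projection into per-head pieces:
# w_kc feeds the "nope" part of the key, w_vc feeds the value.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)

print(w_kc.shape)  # torch.Size([4, 128, 512])
print(w_vc.shape)  # torch.Size([4, 128, 512])

The hasattr(self_attn.kv_b_proj, "qweight") check is the actual fix: an AWQ-quantized kv_b_proj exposes qweight/scales/qzeros rather than a dense weight attribute, which is why the previous unconditional kv_b_proj.weight access needed this branch when MLA is enabled.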