Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>
### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints:
1. Only the PD separation (prefill/decode disaggregation) mode is supported,
and the MooncakeLayerwiseConnector must be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
87 lines
2.8 KiB
Python
import sys
|
|
from typing import Any
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class ImportPatchDecorator:
    """Registry of per-module patch functions.

    A patch is registered against a fully-qualified module name and is a
    callable taking the imported module object. Patches are applied either
    eagerly via :meth:`apply_patches` (for modules already in ``sys.modules``)
    or lazily by the ``__import__`` hook installed elsewhere in this file.
    """

    # Maps fully-qualified module name -> patch callable(module).
    _patches: dict[str, Any] = {}

    @classmethod
    def register(cls, module_name: str):
        """Return a decorator that registers the wrapped function as the
        patch for ``module_name``. The function is returned unchanged."""

        def decorator(func):
            cls._patches[module_name] = func
            return func

        return decorator

    @classmethod
    def apply_patches(cls) -> None:
        """Apply every registered patch whose target module is already
        imported.

        Failures are logged and swallowed so a single bad patch cannot
        break interpreter startup.
        """
        for module_name, patch_func in cls._patches.items():
            if module_name in sys.modules:
                module = sys.modules[module_name]
                try:
                    patch_func(module)
                except Exception as e:
                    # Lazy %-style args (not an f-string) so formatting is
                    # skipped when the record is filtered out.
                    logger.error("Patch application failed %s: %s",
                                 module_name, e)
|
|
|
|
|
|
@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2")
def patch_deepseek(module):
    """Patch deepseek_v2 so that KV-cache (C8) quantization scale/offset
    weight names are remapped onto the MLA attention parameter names.

    Wraps vLLM's ``maybe_remap_kv_scale_name`` with a version that
    additionally remaps ``fa_{q,k,v}.{scale,offset}`` names under the
    ``mla_attn`` parameter path.
    """
    ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name

    def new_remap(name: str, params_dict: dict):
        # Apply the stock remapping first. It may return None when the name
        # is a scale that cannot be remapped; guard before calling string
        # methods on the result (the original code raised AttributeError
        # here on a None return).
        name = ori_maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            return None

        replace_scale_names = [
            "fa_q.scale", "fa_k.scale", "fa_v.scale",
            "fa_q.offset", "fa_k.offset", "fa_v.offset",
        ]

        for scale_name in replace_scale_names:
            if name.endswith(scale_name):
                # Prefer the doubly nested "mla_attn.mla_attn." path when it
                # exists in the checkpoint's params.
                remap_name = name.replace(
                    scale_name, f"mla_attn.mla_attn.{scale_name}")
                if remap_name in params_dict:
                    return remap_name
                # Fall back to a single "mla_attn." level (str.replace
                # removes only the first ".mla_attn" occurrence).
                return remap_name.replace(".mla_attn", "")

        return name

    if hasattr(module, "maybe_remap_kv_scale_name"):
        # Keep a handle to the pre-patch function for debugging/unpatching.
        module._original_maybe_remap_kv_scale_name = \
            module.maybe_remap_kv_scale_name
        module.maybe_remap_kv_scale_name = new_remap
|
|
|
|
|
|
@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils")
|
|
def patch_weight_utils(module):
|
|
if "vllm.model_executor.models.deepseek_v2" in sys.modules:
|
|
deepseek = sys.modules["vllm.model_executor.models.deepseek_v2"]
|
|
if hasattr(deepseek, "maybe_remap_kv_scale_name"):
|
|
module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name
|
|
|
|
|
|
# NOTE: the original code indexed __builtins__["__import__"]. In CPython,
# __builtins__ is a dict only inside imported modules; when a file runs as a
# script (__main__) it is the builtins *module* and subscripting raises
# TypeError. The builtins module works reliably in both cases.
import builtins

# Capture the real __import__ before installing the hook so the wrapper can
# delegate to it.
original_import = builtins.__import__


def patched_import(name, globals=None, locals=None, fromlist=(), level=0):
    """``__import__`` wrapper: after a module is imported, apply any patch
    registered for it under ``name``.

    Signature mirrors builtins.__import__; returns the imported module.
    Patch failures are logged and swallowed so imports never break.
    """
    module = original_import(name, globals, locals, fromlist, level)

    if name in ImportPatchDecorator._patches:
        try:
            ImportPatchDecorator._patches[name](module)
        except Exception as e:
            # Lazy %-style args instead of an f-string (import path is hot).
            logger.error("Patch application failed during import %s: %s",
                         name, e)

    return module


# Install the hook for future imports, then patch any modules that were
# already imported before this file ran.
builtins.__import__ = patched_import

ImportPatchDecorator.apply_patches()
|