[Feature] Supports DSv3.1 PD separation and C8 quantization (#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints: 
1. Only the PD separation mode is supported, and the
MooncakeLayerwiseConnector must be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
pichangping
2026-03-16 22:49:05 +08:00
committed by GitHub
parent a6f6e919e6
commit 3f39ac9c8d
15 changed files with 1112 additions and 161 deletions

View File

@@ -66,7 +66,7 @@ from vllm_ascend.distributed.kv_transfer.utils.utils import (
kv_alltoall_and_rearrange,
parallel_info,
)
from vllm_ascend.utils import npu_stream_switch
from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz
# isort: off
if TYPE_CHECKING:
@@ -124,6 +124,9 @@ class SendTask:
# pd_head_ratio > 1 use
k_cache: torch.Tensor | None = None
v_cache: torch.Tensor | None = None
# kv cache quantization layer use
k_quant_cache: torch.Tensor | None = None
v_quant_cache: torch.Tensor | None = None
layer_idx: int = 0
layer_name: str = ""
# trans block info
@@ -210,6 +213,9 @@ class KVCacheSendingLayerThread(threading.Thread):
use_mla: bool,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
enable_kv_quant: bool,
k_quant_buffer: torch.Tensor | None,
v_quant_buffer: torch.Tensor | None,
resharding_stream: torch.npu.Stream,
callback_func: Callable[..., None] = lambda x: None,
):
@@ -232,6 +238,9 @@ class KVCacheSendingLayerThread(threading.Thread):
self.send_queue = queue.Queue[SendTask]()
self.k_buffer = k_buffer
self.v_buffer = v_buffer
self.enable_kv_quant = enable_kv_quant
self.k_quant_buffer = k_quant_buffer
self.v_quant_buffer = v_quant_buffer
self.ready_event = ready_event
self.callback_func = callback_func
@@ -325,19 +334,43 @@ class KVCacheSendingLayerThread(threading.Thread):
grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous(
remote_block_ids, local_block_ids
)
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
block_len = block_lens[k]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
# kv cache quantization scenario
if self.enable_kv_quant and send_task.k_quant_cache is not None:
assert len(block_lens) == 2, "Quantization block length must be 2!"
quant_block_lens = [block_lens[0] // 2, block_lens[1]]
layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()]
rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
# eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3}
rearrange_block_dict = {
value: index
for index, value in enumerate(rearrange_block_ids) # type:ignore
}
for block_len, src_layer_base_addr, dst_layer_base_addr in zip(
quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr
):
src = src_layer_base_addr + group_local_block_id[0] * block_len
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
):
src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
else:
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
block_len = block_lens[k]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
):
src = src_layer_base_addr + group_local_block_id[0] * block_len
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
else:
rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
rearrange_block_dict = {
@@ -380,6 +413,14 @@ class KVCacheSendingLayerThread(threading.Thread):
value = value.view(-1, key.shape[-1]) # type:ignore
self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] ->
self.v_buffer[: value.shape[0]].copy_(value)
if send_task.k_quant_cache is not None:
with npu_stream_switch(self.resharding_stream):
key_quant = send_task.k_quant_cache
key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore
self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant)
value_quant = send_task.v_quant_cache
value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore
self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant)
# Merge transmission tasks of the same session
session_meta: dict[str, TransferMeta] = {}
@@ -395,7 +436,9 @@ class KVCacheSendingLayerThread(threading.Thread):
session_meta[session_id].length.extend(length_list)
session_meta[session_id].req_ids.append(req_id)
if self.pd_head_ratio == 1:
if send_task.k_quant_cache is not None:
self.resharding_stream.synchronize()
elif self.pd_head_ratio == 1:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
@@ -628,7 +671,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.wait_for_layer_load(layer_name)
def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs
) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
assert self.connector_worker is not None
@@ -962,10 +1005,13 @@ class MooncakeLayerwiseConnectorWorker:
self.layer_metadata: dict[str, LayerMetadata] = {}
self.attn_resharding_group_idx = set[int]()
self.enable_kv_quant = (
vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False
)
self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.num_head_replica = get_ascend_config().num_head_replica
self.resharding_stream = None
if self.pd_head_ratio > 1:
if self.pd_head_ratio > 1 or self.enable_kv_quant:
self.resharding_stream = torch.npu.Stream()
self.remote_poller = zmq.Poller() # type: ignore
@@ -985,11 +1031,16 @@ class MooncakeLayerwiseConnectorWorker:
self.timeout = 1.0 # seconds
self.k_buffer: torch.Tensor | None = None
self.v_buffer: torch.Tensor | None = None
# TODO(kunpengW-code): Reuse k_buffer, v_buffer
self.k_quant_buffer: torch.Tensor | None = None
self.v_quant_buffer: torch.Tensor | None = None
def create_kv_buffer(self, first_kv_cache):
def create_kv_buffer(self, first_kv_cache_tuple):
alignment = 2 * 1024 * 1024
buffer_list = []
first_kv_cache = first_kv_cache_tuple[0]
if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal
alignment = 2 * 1024 * 1024
self.k_buffer = torch.zeros(
first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
)
@@ -1002,18 +1053,34 @@ class MooncakeLayerwiseConnectorWorker:
self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1]
)
buffer_list.append(self.k_buffer)
buffer_list.append(self.v_buffer)
if self.enable_kv_quant:
quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
self.k_quant_buffer = torch.zeros(
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
)
self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view(
-1, first_kv_cache.shape[-1]
)
quant_v_cache_numel = first_kv_cache_tuple[1].numel()
self.v_quant_buffer = torch.zeros(
quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
)
self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view(
-1, first_kv_cache_tuple[1].shape[-1]
)
buffer_list.append(self.k_quant_buffer)
buffer_list.append(self.v_quant_buffer)
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr() % alignment == 0, (
"The address of the registered kv cache should be aligned to 2M"
)
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} "
f"{tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
for tensor in buffer_list:
assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
@@ -1042,8 +1109,8 @@ class MooncakeLayerwiseConnectorWorker:
ptrs = []
lengths = []
use_resharding_buffer = False
resharding_buffer = None
use_kv_buffer = False
kv_buffer = None
for layer_name, kv_cache_tuple in kv_caches.items():
if isinstance(kv_cache_tuple, (list, tuple)) is False:
kv_cache_tuple = [kv_cache_tuple]
@@ -1051,12 +1118,13 @@ class MooncakeLayerwiseConnectorWorker:
layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))):
if (
self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec)))
) or self.enable_kv_quant:
self.attn_resharding_group_idx.add(layer_kv_group_id)
if use_resharding_buffer is False:
use_resharding_buffer = True
resharding_buffer = kv_cache_tuple[0]
self.resharding_stream = torch.npu.Stream()
if use_kv_buffer is False:
use_kv_buffer = True
kv_buffer = kv_cache_tuple
single_layer_meta = LayerMetadata([], [], [], [])
for single_kv_cache in kv_cache_tuple:
block_start_rank = 1
@@ -1092,8 +1160,8 @@ class MooncakeLayerwiseConnectorWorker:
lengths.append(kv_cache_tensor.size)
global_te.register_buffer(ptrs, lengths)
if use_resharding_buffer:
self.create_kv_buffer(resharding_buffer)
if use_kv_buffer:
self.create_kv_buffer(kv_buffer)
num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1
mtp_layer_name = ""
@@ -1133,6 +1201,9 @@ class MooncakeLayerwiseConnectorWorker:
use_mla=self.use_mla,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
enable_kv_quant=self.enable_kv_quant,
k_quant_buffer=self.k_quant_buffer,
v_quant_buffer=self.v_quant_buffer,
resharding_stream=self.resharding_stream,
callback_func=self.send_done_send_signal,
)
@@ -1380,7 +1451,7 @@ class MooncakeLayerwiseConnectorWorker:
metadata.requests[req_id] = update_metadata[req_id]
# update send task trans block info
if self.pd_head_ratio != 1:
if self.pd_head_ratio != 1 or self.enable_kv_quant:
send_task = metadata.send_task
send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)]
send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)]
@@ -1388,7 +1459,7 @@ class MooncakeLayerwiseConnectorWorker:
send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)]
send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)]
send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)]
device = self.k_buffer.device # type: ignore
device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore
for i in self.attn_resharding_group_idx:
send_task.group_rearrange_block_ids[i].extend(
sorted(
@@ -1415,7 +1486,7 @@ class MooncakeLayerwiseConnectorWorker:
def save_kv_layer(
self,
layer_name: str,
kv_layer: tuple[torch.Tensor, torch.Tensor],
kv_layer: list[torch.Tensor],
attn_metadata: "AttentionMetadata",
connector_metadata: MooncakeLayerwiseConnectorMetadata,
**kwargs,
@@ -1490,12 +1561,51 @@ class MooncakeLayerwiseConnectorWorker:
values = values.reshape(-1, *kv_layer[1].shape[2:])
(keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
quant_keys = None
quant_values = None
if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
device = self.k_quant_buffer.device # type: ignore
layer = self.vllm_config.compilation_config.static_forward_context[layer_name]
# Initialize buffers
# [num_tokens, kv_head, head_dim]
quant_key = torch.empty(
(send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]),
dtype=kv_layer[0].dtype,
device=device,
)
quant_values = torch.empty(
(send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]),
dtype=kv_layer[1].dtype,
device=device,
)
# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
kv_layer[0],
kv_layer[1],
send_task.group_block_table[layer_group_idx],
send_task.group_block_len_tensor[layer_group_idx],
seq_starts=send_task.group_seq_start_tensor[layer_group_idx],
key=quant_key,
value=quant_values,
)
quant_keys = torch.ops.vllm.quantize(
quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset
)
quant_keys = self.get_nz_cache(quant_keys, layer_group_idx)
quant_values = self.get_nz_cache(quant_values, layer_group_idx)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
layer_send_task = SendTask(
wait_event=reshape_cache_event,
k_cache=keys,
v_cache=values,
k_quant_cache=quant_keys,
v_quant_cache=quant_values,
layer_idx=self.current_layer,
layer_name=layer_name,
group_rearrange_block_ids=send_task.group_rearrange_block_ids,
@@ -1510,6 +1620,15 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_send_layer_thread.send_queue.put(layer_send_task)
self.current_layer += 1
# NOTE: Due to the FIA operator constraints, the expected kv cache is ND format, NZ shape,
# while the npu_format_cast method only modifies the memory layout, we manually convert it to NZ shape here
def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int):
head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1]
cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim)
cache_tensor = trans_nd_to_nz(cache_tensor)
cache_tensor = cache_tensor.reshape(-1, head_num, head_dim)
return cache_tensor
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
"""Get a socket to the remote host."""
remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)