[Feature] Supports DSv3.1 PD separation and C8 quantization (#7222)
Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>
### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints:
1. The model can only be run in PD separation mode, using the
MooncakeLayerwiseConnector.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -66,7 +66,7 @@ from vllm_ascend.distributed.kv_transfer.utils.utils import (
|
||||
kv_alltoall_and_rearrange,
|
||||
parallel_info,
|
||||
)
|
||||
from vllm_ascend.utils import npu_stream_switch
|
||||
from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz
|
||||
|
||||
# isort: off
|
||||
if TYPE_CHECKING:
|
||||
@@ -124,6 +124,9 @@ class SendTask:
|
||||
# pd_head_ratio > 1 use
|
||||
k_cache: torch.Tensor | None = None
|
||||
v_cache: torch.Tensor | None = None
|
||||
# kv cache quantization layer use
|
||||
k_quant_cache: torch.Tensor | None = None
|
||||
v_quant_cache: torch.Tensor | None = None
|
||||
layer_idx: int = 0
|
||||
layer_name: str = ""
|
||||
# trans block info
|
||||
@@ -210,6 +213,9 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
use_mla: bool,
|
||||
k_buffer: torch.Tensor,
|
||||
v_buffer: torch.Tensor,
|
||||
enable_kv_quant: bool,
|
||||
k_quant_buffer: torch.Tensor | None,
|
||||
v_quant_buffer: torch.Tensor | None,
|
||||
resharding_stream: torch.npu.Stream,
|
||||
callback_func: Callable[..., None] = lambda x: None,
|
||||
):
|
||||
@@ -232,6 +238,9 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
self.send_queue = queue.Queue[SendTask]()
|
||||
self.k_buffer = k_buffer
|
||||
self.v_buffer = v_buffer
|
||||
self.enable_kv_quant = enable_kv_quant
|
||||
self.k_quant_buffer = k_quant_buffer
|
||||
self.v_quant_buffer = v_quant_buffer
|
||||
self.ready_event = ready_event
|
||||
self.callback_func = callback_func
|
||||
|
||||
@@ -325,19 +334,43 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous(
|
||||
remote_block_ids, local_block_ids
|
||||
)
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
|
||||
):
|
||||
block_len = block_lens[k]
|
||||
for group_remote_block_id, group_local_block_id in zip(
|
||||
grouped_remote_block_ids, grouped_local_block_ids
|
||||
# kv cache quantization scenario
|
||||
if self.enable_kv_quant and send_task.k_quant_cache is not None:
|
||||
assert len(block_lens) == 2, "Quantization block length must be 2!"
|
||||
quant_block_lens = [block_lens[0] // 2, block_lens[1]]
|
||||
layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()]
|
||||
rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
|
||||
# eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3}
|
||||
rearrange_block_dict = {
|
||||
value: index
|
||||
for index, value in enumerate(rearrange_block_ids) # type:ignore
|
||||
}
|
||||
for block_len, src_layer_base_addr, dst_layer_base_addr in zip(
|
||||
quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr
|
||||
):
|
||||
src = src_layer_base_addr + group_local_block_id[0] * block_len
|
||||
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
|
||||
length = len(group_local_block_id) * block_len
|
||||
src_list.append(src)
|
||||
dst_list.append(dst)
|
||||
length_list.append(length)
|
||||
for group_remote_block_id, group_local_block_id in zip(
|
||||
grouped_remote_block_ids, grouped_local_block_ids
|
||||
):
|
||||
src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len
|
||||
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
|
||||
length = len(group_local_block_id) * block_len
|
||||
src_list.append(src)
|
||||
dst_list.append(dst)
|
||||
length_list.append(length)
|
||||
else:
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
|
||||
):
|
||||
block_len = block_lens[k]
|
||||
for group_remote_block_id, group_local_block_id in zip(
|
||||
grouped_remote_block_ids, grouped_local_block_ids
|
||||
):
|
||||
src = src_layer_base_addr + group_local_block_id[0] * block_len
|
||||
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
|
||||
length = len(group_local_block_id) * block_len
|
||||
src_list.append(src)
|
||||
dst_list.append(dst)
|
||||
length_list.append(length)
|
||||
else:
|
||||
rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
|
||||
rearrange_block_dict = {
|
||||
@@ -380,6 +413,14 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
value = value.view(-1, key.shape[-1]) # type:ignore
|
||||
self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] ->
|
||||
self.v_buffer[: value.shape[0]].copy_(value)
|
||||
if send_task.k_quant_cache is not None:
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
key_quant = send_task.k_quant_cache
|
||||
key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore
|
||||
self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant)
|
||||
value_quant = send_task.v_quant_cache
|
||||
value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore
|
||||
self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant)
|
||||
|
||||
# Merge transmission tasks of the same session
|
||||
session_meta: dict[str, TransferMeta] = {}
|
||||
@@ -395,7 +436,9 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
session_meta[session_id].length.extend(length_list)
|
||||
session_meta[session_id].req_ids.append(req_id)
|
||||
|
||||
if self.pd_head_ratio == 1:
|
||||
if send_task.k_quant_cache is not None:
|
||||
self.resharding_stream.synchronize()
|
||||
elif self.pd_head_ratio == 1:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
@@ -628,7 +671,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self.connector_worker.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
|
||||
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
|
||||
self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs
|
||||
) -> None:
|
||||
"""MooncakeLayerwiseConnector does not save explicitly."""
|
||||
assert self.connector_worker is not None
|
||||
@@ -962,10 +1005,13 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.layer_metadata: dict[str, LayerMetadata] = {}
|
||||
self.attn_resharding_group_idx = set[int]()
|
||||
|
||||
self.enable_kv_quant = (
|
||||
vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False
|
||||
)
|
||||
self.pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
self.num_head_replica = get_ascend_config().num_head_replica
|
||||
self.resharding_stream = None
|
||||
if self.pd_head_ratio > 1:
|
||||
if self.pd_head_ratio > 1 or self.enable_kv_quant:
|
||||
self.resharding_stream = torch.npu.Stream()
|
||||
|
||||
self.remote_poller = zmq.Poller() # type: ignore
|
||||
@@ -985,11 +1031,16 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.timeout = 1.0 # seconds
|
||||
self.k_buffer: torch.Tensor | None = None
|
||||
self.v_buffer: torch.Tensor | None = None
|
||||
# TODO(kunpengW-code): Reuse k_buffer, v_buffer
|
||||
self.k_quant_buffer: torch.Tensor | None = None
|
||||
self.v_quant_buffer: torch.Tensor | None = None
|
||||
|
||||
def create_kv_buffer(self, first_kv_cache):
|
||||
def create_kv_buffer(self, first_kv_cache_tuple):
|
||||
alignment = 2 * 1024 * 1024
|
||||
buffer_list = []
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
if self.pd_head_ratio > 1:
|
||||
# regesit kv buffer for tp inequal
|
||||
alignment = 2 * 1024 * 1024
|
||||
self.k_buffer = torch.zeros(
|
||||
first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
|
||||
)
|
||||
@@ -1002,18 +1053,34 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view(
|
||||
-1, first_kv_cache.shape[-1]
|
||||
)
|
||||
buffer_list.append(self.k_buffer)
|
||||
buffer_list.append(self.v_buffer)
|
||||
if self.enable_kv_quant:
|
||||
quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
|
||||
self.k_quant_buffer = torch.zeros(
|
||||
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
|
||||
)
|
||||
self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view(
|
||||
-1, first_kv_cache.shape[-1]
|
||||
)
|
||||
quant_v_cache_numel = first_kv_cache_tuple[1].numel()
|
||||
self.v_quant_buffer = torch.zeros(
|
||||
quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
|
||||
)
|
||||
self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view(
|
||||
-1, first_kv_cache_tuple[1].shape[-1]
|
||||
)
|
||||
buffer_list.append(self.k_quant_buffer)
|
||||
buffer_list.append(self.v_quant_buffer)
|
||||
|
||||
for tensor in (self.k_buffer, self.v_buffer):
|
||||
assert tensor.data_ptr() % alignment == 0, (
|
||||
"The address of the registered kv cache should be aligned to 2M"
|
||||
)
|
||||
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
|
||||
logger.info(
|
||||
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} "
|
||||
f"{tensor.numel()} {ret_value=}"
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed. ")
|
||||
for tensor in buffer_list:
|
||||
assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
|
||||
logger.info(
|
||||
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed. ")
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data."""
|
||||
@@ -1042,8 +1109,8 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
|
||||
ptrs = []
|
||||
lengths = []
|
||||
use_resharding_buffer = False
|
||||
resharding_buffer = None
|
||||
use_kv_buffer = False
|
||||
kv_buffer = None
|
||||
for layer_name, kv_cache_tuple in kv_caches.items():
|
||||
if isinstance(kv_cache_tuple, (list, tuple)) is False:
|
||||
kv_cache_tuple = [kv_cache_tuple]
|
||||
@@ -1051,12 +1118,13 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
|
||||
if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))):
|
||||
if (
|
||||
self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec)))
|
||||
) or self.enable_kv_quant:
|
||||
self.attn_resharding_group_idx.add(layer_kv_group_id)
|
||||
if use_resharding_buffer is False:
|
||||
use_resharding_buffer = True
|
||||
resharding_buffer = kv_cache_tuple[0]
|
||||
self.resharding_stream = torch.npu.Stream()
|
||||
if use_kv_buffer is False:
|
||||
use_kv_buffer = True
|
||||
kv_buffer = kv_cache_tuple
|
||||
single_layer_meta = LayerMetadata([], [], [], [])
|
||||
for single_kv_cache in kv_cache_tuple:
|
||||
block_start_rank = 1
|
||||
@@ -1092,8 +1160,8 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
lengths.append(kv_cache_tensor.size)
|
||||
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
if use_resharding_buffer:
|
||||
self.create_kv_buffer(resharding_buffer)
|
||||
if use_kv_buffer:
|
||||
self.create_kv_buffer(kv_buffer)
|
||||
|
||||
num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1
|
||||
mtp_layer_name = ""
|
||||
@@ -1133,6 +1201,9 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
use_mla=self.use_mla,
|
||||
k_buffer=self.k_buffer,
|
||||
v_buffer=self.v_buffer,
|
||||
enable_kv_quant=self.enable_kv_quant,
|
||||
k_quant_buffer=self.k_quant_buffer,
|
||||
v_quant_buffer=self.v_quant_buffer,
|
||||
resharding_stream=self.resharding_stream,
|
||||
callback_func=self.send_done_send_signal,
|
||||
)
|
||||
@@ -1380,7 +1451,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
metadata.requests[req_id] = update_metadata[req_id]
|
||||
|
||||
# update send task trans block info
|
||||
if self.pd_head_ratio != 1:
|
||||
if self.pd_head_ratio != 1 or self.enable_kv_quant:
|
||||
send_task = metadata.send_task
|
||||
send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)]
|
||||
send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)]
|
||||
@@ -1388,7 +1459,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)]
|
||||
send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)]
|
||||
send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)]
|
||||
device = self.k_buffer.device # type: ignore
|
||||
device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore
|
||||
for i in self.attn_resharding_group_idx:
|
||||
send_task.group_rearrange_block_ids[i].extend(
|
||||
sorted(
|
||||
@@ -1415,7 +1486,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: tuple[torch.Tensor, torch.Tensor],
|
||||
kv_layer: list[torch.Tensor],
|
||||
attn_metadata: "AttentionMetadata",
|
||||
connector_metadata: MooncakeLayerwiseConnectorMetadata,
|
||||
**kwargs,
|
||||
@@ -1490,12 +1561,51 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
||||
(keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
|
||||
|
||||
quant_keys = None
|
||||
quant_values = None
|
||||
if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers:
|
||||
assert self.resharding_stream is not None
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
reshape_cache_event.wait()
|
||||
device = self.k_quant_buffer.device # type: ignore
|
||||
layer = self.vllm_config.compilation_config.static_forward_context[layer_name]
|
||||
# Initialize buffers
|
||||
# [num_tokens, kv_head, head_dim]
|
||||
quant_key = torch.empty(
|
||||
(send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]),
|
||||
dtype=kv_layer[0].dtype,
|
||||
device=device,
|
||||
)
|
||||
quant_values = torch.empty(
|
||||
(send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]),
|
||||
dtype=kv_layer[1].dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Load cache data into buffers
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
kv_layer[0],
|
||||
kv_layer[1],
|
||||
send_task.group_block_table[layer_group_idx],
|
||||
send_task.group_block_len_tensor[layer_group_idx],
|
||||
seq_starts=send_task.group_seq_start_tensor[layer_group_idx],
|
||||
key=quant_key,
|
||||
value=quant_values,
|
||||
)
|
||||
quant_keys = torch.ops.vllm.quantize(
|
||||
quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset
|
||||
)
|
||||
quant_keys = self.get_nz_cache(quant_keys, layer_group_idx)
|
||||
quant_values = self.get_nz_cache(quant_values, layer_group_idx)
|
||||
|
||||
assert self.kv_send_layer_thread is not None
|
||||
assert reshape_cache_event is not None
|
||||
layer_send_task = SendTask(
|
||||
wait_event=reshape_cache_event,
|
||||
k_cache=keys,
|
||||
v_cache=values,
|
||||
k_quant_cache=quant_keys,
|
||||
v_quant_cache=quant_values,
|
||||
layer_idx=self.current_layer,
|
||||
layer_name=layer_name,
|
||||
group_rearrange_block_ids=send_task.group_rearrange_block_ids,
|
||||
@@ -1510,6 +1620,15 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.kv_send_layer_thread.send_queue.put(layer_send_task)
|
||||
self.current_layer += 1
|
||||
|
||||
# NOTE: Due to the FIA operator constraints, the expected kv cache is ND format, NZ shape,
|
||||
# while the npu_format_cast method only modifies the memory layout, we manually convert it to NZ shape here
|
||||
def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int):
|
||||
head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1]
|
||||
cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim)
|
||||
cache_tensor = trans_nd_to_nz(cache_tensor)
|
||||
cache_tensor = cache_tensor.reshape(-1, head_num, head_dim)
|
||||
return cache_tensor
|
||||
|
||||
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
|
||||
"""Get a socket to the remote host."""
|
||||
remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
|
||||
|
||||
Reference in New Issue
Block a user