[Feature] Support kv nz feature for DeepSeek decode node in disagg-prefill scenario (#3072)
By converting the KV cache from ND to NZ format when the decode node
receives it, this PR ensures that the KV NZ feature works correctly
during the decoding phase in disagg-prefill scenario.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: ghphotoframe <854746559@qq.com>
Co-authored-by: alex101-ops <alex1015718386@gmail.com>
This commit is contained in:
@@ -412,8 +412,10 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
logger.debug(
|
||||
f"Finished transferring KV cache for request {request_id}.")
|
||||
except Exception as e:
|
||||
logger.error("Failed to transfer KV cache for request "
|
||||
f"{request_id}: {e}")
|
||||
logger.error(
|
||||
"Failed to transfer KV cache for request "
|
||||
f"{request_id}: {e}",
|
||||
exc_info=True)
|
||||
finally:
|
||||
# Always send the done signal to the remote host to ensure proper
|
||||
# resource cleanup. Failing to do so may cause a memory leak on the
|
||||
@@ -539,97 +541,116 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
|
||||
get_ip(), self.tp_rank, session_id)
|
||||
|
||||
# Determine if the current position is the offset position at the end of the KV transmission.
|
||||
# Determine if the current position is the offset position at the end of
|
||||
# the KV transmission.
|
||||
is_kv_transfer_end = (
|
||||
global_offset == tp_num_need_pulls * self._prefill_pp_size - 1)
|
||||
need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
|
||||
# need_nz_cache maybe caused error in non-MLA models
|
||||
if need_cat_cache:
|
||||
self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls)
|
||||
need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
|
||||
if need_nz_cache or need_cat_cache:
|
||||
self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls,
|
||||
need_cat_cache, need_nz_cache)
|
||||
|
||||
def _cat_kv_cache(self, block_ids: list[list[int]],
|
||||
tp_num_need_pulls: int):
|
||||
def reformat_kv_cache(self,
|
||||
block_ids: list[list[int]],
|
||||
tp_num_need_pulls: int,
|
||||
need_cat_cache: bool = False,
|
||||
need_nz_cache: bool = False):
|
||||
# Get necessary parameters
|
||||
k_cache = list(self.kv_caches.values())[0][0]
|
||||
dtype = k_cache.dtype
|
||||
device = k_cache.device
|
||||
head_dim = self.model_config.hf_text_config.head_dim
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
num_kv_head = max(
|
||||
self.model_config.hf_text_config.num_key_value_heads //
|
||||
self.tp_size, 1)
|
||||
|
||||
flat_block_ids = [item for sublist in block_ids for item in sublist]
|
||||
block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
|
||||
block_ids_tensor = torch.tensor(flat_block_ids,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
num_blocks = len(flat_block_ids)
|
||||
block_len = num_blocks * block_size
|
||||
num_tokens = num_blocks * self.block_size
|
||||
|
||||
# Create device tensors for copy operations
|
||||
block_table = block_ids_tensor.view(1, -1).to(device=device)
|
||||
block_len_tensor = torch.tensor([block_len],
|
||||
dtype=torch.int32).to(device=device)
|
||||
seq_start_tensor = torch.tensor([0],
|
||||
dtype=torch.int32).to(device=device)
|
||||
block_table = block_ids_tensor.view(1, -1)
|
||||
block_len_tensor = torch.tensor([num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
|
||||
|
||||
# Initialize buffers
|
||||
k_buffer = torch.empty(block_len,
|
||||
num_kv_head,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
v_buffer = torch.empty(block_len,
|
||||
num_kv_head,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_buffer = torch.empty(
|
||||
(num_tokens, self.num_kv_heads, self.k_head_dim),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
v_buffer = torch.empty(
|
||||
(num_tokens, self.num_kv_heads, self.v_head_dim),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# Create slot mapping for reshape operations
|
||||
block_offsets = torch.arange(0, block_size, dtype=torch.int32)
|
||||
block_offsets = torch.arange(0,
|
||||
self.block_size,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
slot_mapping = (block_offsets.reshape(
|
||||
(1, block_size)) + block_ids_tensor.reshape(
|
||||
(num_blocks, 1)) * block_size)
|
||||
slot_mapping = slot_mapping.flatten().to(device=device)
|
||||
(1, self.block_size)) + block_ids_tensor.reshape(
|
||||
(num_blocks, 1)) * self.block_size).flatten()
|
||||
|
||||
# FIXME: Right now, if we skip synchronization at this point, the system
|
||||
# will crash in GQA scenarios. However, we still haven't identified the
|
||||
# root cause.
|
||||
torch.npu.synchronize()
|
||||
|
||||
# Process each layer in the KV cache
|
||||
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
|
||||
# Load cache data into buffers
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
k_cache_layer,
|
||||
v_cache_layer,
|
||||
block_table,
|
||||
block_len_tensor,
|
||||
seq_starts=seq_start_tensor,
|
||||
key=k_buffer,
|
||||
value=v_buffer,
|
||||
)
|
||||
|
||||
# Transpose KV cache
|
||||
k_buffer = self._transpose_kv_cache_between_head(
|
||||
k_buffer, num_blocks, block_size, block_len, num_kv_head,
|
||||
tp_num_need_pulls)
|
||||
v_buffer = self._transpose_kv_cache_between_head(
|
||||
v_buffer, num_blocks, block_size, block_len, num_kv_head,
|
||||
tp_num_need_pulls)
|
||||
|
||||
# Reshape and cache the processed buffers
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=k_buffer,
|
||||
value=v_buffer,
|
||||
key_cache=k_cache_layer,
|
||||
value_cache=v_cache_layer,
|
||||
slot_indices=slot_mapping,
|
||||
)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(k_cache_layer,
|
||||
v_cache_layer,
|
||||
block_table,
|
||||
block_len_tensor,
|
||||
seq_starts=seq_start_tensor,
|
||||
key=k_buffer,
|
||||
value=v_buffer)
|
||||
if need_cat_cache:
|
||||
self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
|
||||
v_buffer, tp_num_need_pulls, num_blocks,
|
||||
num_tokens, slot_mapping)
|
||||
if need_nz_cache:
|
||||
self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
|
||||
v_buffer, slot_mapping)
|
||||
# Clean up buffers
|
||||
del k_buffer, v_buffer
|
||||
|
||||
def _transpose_kv_cache_between_head(
|
||||
self, buffer: torch.Tensor, num_blocks: int, block_size: int,
|
||||
block_len: int, num_kv_head: int,
|
||||
tp_num_need_pulls: int) -> torch.Tensor:
|
||||
buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1)
|
||||
buffer.transpose_(1, 2)
|
||||
return buffer.contiguous().view(block_len, num_kv_head, -1)
|
||||
def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
|
||||
tp_num_need_pulls, num_blocks, num_tokens, slot_mapping):
|
||||
|
||||
def _transpose_kv_cache_between_head(
|
||||
buffer: torch.Tensor) -> torch.Tensor:
|
||||
buffer = buffer.view(num_blocks, tp_num_need_pulls,
|
||||
self.block_size, -1)
|
||||
buffer.transpose_(1, 2)
|
||||
return buffer.contiguous().view(num_tokens, self.num_kv_heads, -1)
|
||||
|
||||
# Transpose KV cache
|
||||
k_buffer = _transpose_kv_cache_between_head(k_buffer)
|
||||
v_buffer = _transpose_kv_cache_between_head(v_buffer)
|
||||
|
||||
# Reshape and cache the processed buffers
|
||||
torch_npu._npu_reshape_and_cache(key=k_buffer,
|
||||
value=v_buffer,
|
||||
key_cache=k_cache_layer,
|
||||
value_cache=v_cache_layer,
|
||||
slot_indices=slot_mapping)
|
||||
|
||||
def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
|
||||
slot_mapping):
|
||||
nz_fmt_last_dim = 16
|
||||
k_cache_layer = k_cache_layer.view(
|
||||
-1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim,
|
||||
self.block_size, nz_fmt_last_dim)
|
||||
v_cache_layer = v_cache_layer.view(
|
||||
-1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim,
|
||||
self.block_size, nz_fmt_last_dim)
|
||||
torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer,
|
||||
v_cache_layer, slot_mapping)
|
||||
|
||||
def _get_remote_metadata(self, remote_host: str,
|
||||
remote_handshake_port: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user