xc-llm-ascend/vllm_ascend/distributed/kv_transfer/simple_buffer.py

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import zlib
from typing import List, Optional

import llm_datadist  # type: ignore
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
    KVLookupBufferBase
from vllm.logger import init_logger

from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE

logger = init_logger(__name__)


# Hash a string into an int32 value.
def int32_hash(data):
    assert isinstance(data, str)
    data = data.encode("utf-8")
    return zlib.adler32(data)
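# Example (sanity check; the value is simply what zlib.adler32 returns for
# the UTF-8 bytes):
#   int32_hash("abc") -> 38600999 (0x024D0127)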


class SimpleBuffer(KVLookupBufferBase):

    def __init__(self, data_pipe: SimplePipe):
        self.data_pipe = data_pipe
        # The consumer buffer needs this information to construct the
        # receiving buffers.
        self.num_layers = None
        self.num_heads = None
        self.head_size = None
        self.dtype = None
        self.hidden_size = None
        self.key_buffer = None
        self.value_buffer = None
        self.hidden_buffer = None
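        # NOTE: this class never fills in the shape/dtype fields above; on the
        # consumer (decode) side they are expected to be populated by the
        # surrounding KV-transfer connector before drop_select() is called.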

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
        req_id: str,
    ) -> None:
        """
        seq_len: num_tokens of the current request.
        input_tokens: [seq_len]
        roi: [seq_len]
        key: [num_layers, seq_len, num_kv_heads, head_size]
        value: [num_layers, seq_len, num_kv_heads, head_size]
        hidden: [seq_len, hidden_size]
        """
        orig_k_shape = key.shape
        num_layers = orig_k_shape[0]
        # Unsqueeze all tensors so that their first dim becomes 1.
        # The D node can only pull one batch of data from P, so we make the
        # first dim 1 here in order to pull the full data.
        key = key.view(num_layers, -1).unsqueeze(0)
        value = value.view(num_layers, -1).unsqueeze(0)
        hidden = hidden.unsqueeze(0)
        hidden_dtype = key.dtype
        # Initialize the LLMDataDist data structures.
        key_desc = llm_datadist.CacheDesc(
            1,
            key.shape,
            TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
            seq_len_dim_index=1,
        )
        value_desc = llm_datadist.CacheDesc(
            1,
            value.shape,
            TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
            seq_len_dim_index=1,
        )
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden.shape,
            TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
            seq_len_dim_index=-1,
        )
        req_id = int32_hash(req_id)
        key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                              req_id, 1)
        value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                                req_id, 2)
        hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                                 req_id, 3)
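        # The trailing 1 / 2 / 3 constants tag the key, value and hidden
        # caches of the same hashed request id, so the consumer can address
        # each one separately by building a matching CacheKey.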
        # We currently use the hash of the request id as the key, so there is
        # no need to send input_tokens.
        self.key_buffer = self.data_pipe.send_tensor(key, key_desc,
                                                     key_cache_key)
        self.value_buffer = self.data_pipe.send_tensor(value, value_desc,
                                                       value_cache_key)
        self.hidden_buffer = self.data_pipe.send_tensor(
            hidden, hidden_desc, hidden_cache_key)

    def drop_select(
        self,
        input_tokens: torch.Tensor,
        roi: Optional[torch.Tensor],
        req_id: str,
    ) -> List[Optional[torch.Tensor]]:
        """Select and *drop* KV cache entries from the lookup buffer.

        The functionality is similar to the following python statements:
        ```
        ret = buffer.pop(input_tokens, roi)
        return ret
        ```

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.
            req_id (str): The request id whose caches should be pulled.

        Returns:
            A list of tensors including:
                key: [num_layers, num_tokens, num_heads, head_size]
                value: [num_layers, num_tokens, num_heads, head_size]
                hidden_or_intermediate_states: [num_tokens, hidden_size]
                roi: None (roi is currently not supported)
        """
        orig_req_id = req_id
        req_id = int32_hash(req_id)
        num_tokens = input_tokens.shape[0]
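        # The producer side flattened key/value to
        # (1, num_layers, seq_len * num_kv_heads * head_size) in insert(), so
        # the receive descriptors below mirror that layout; the tensors are
        # reshaped back to 4-D after recv_tensor().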
        kv_shape = (
            1,
            self.num_layers,
            num_tokens * self.num_heads * self.head_size,
        )
        hidden_shape = (1, num_tokens, self.hidden_size)
        key_desc = llm_datadist.CacheDesc(
            1,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
            seq_len_dim_index=-1,
        )
        value_desc = llm_datadist.CacheDesc(
            1,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
            seq_len_dim_index=-1,
        )
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
            seq_len_dim_index=-1,
        )
        key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                              req_id, 1)
        value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                                req_id, 2)
        hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
                                                 req_id, 3)
        # Deallocate the buffers allocated in the previous round.
        if self.key_buffer:
            try:
                self.data_pipe.deallocate_buffer(self.key_buffer)
                self.data_pipe.deallocate_buffer(self.value_buffer)
                self.data_pipe.deallocate_buffer(self.hidden_buffer)
            except Exception as e:
                logger.warning(
                    f"Failed to free kv cache buffer, Error code: {str(e)}")
        try:
            self.key_buffer, key = self.data_pipe.recv_tensor(
                key_desc, key_cache_key)
            self.value_buffer, value = self.data_pipe.recv_tensor(
                value_desc, value_cache_key)
            self.hidden_buffer, hidden = self.data_pipe.recv_tensor(
                hidden_desc, hidden_cache_key)
            key = key.view(self.num_layers, num_tokens, self.num_heads,
                           self.head_size)
            value = value.view(self.num_layers, num_tokens, self.num_heads,
                               self.head_size)
            hidden = hidden.view(num_tokens, self.hidden_size)
        except Exception as e:
            logger.warning(
                f"Failed to receive kv cache and hidden states of request "
                f"{orig_req_id}. Error is {str(e)}")
            return [None, None, None, None]
        return [key, value, hidden, roi]

    def close(self):
        pass
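

# Minimal usage sketch (illustrative only; the shape numbers, request id and
# the pre-built SimplePipe below are assumptions, not part of this module):
#
#   # Producer (prefill / P node):
#   buffer = SimpleBuffer(data_pipe)
#   buffer.insert(input_tokens, roi, key, value, hidden, req_id="req-0")
#
#   # Consumer (decode / D node): shape/dtype metadata must be set first.
#   buffer = SimpleBuffer(data_pipe)
#   buffer.num_layers, buffer.num_heads = 32, 8
#   buffer.head_size, buffer.hidden_size = 128, 4096
#   buffer.dtype = torch.float16
#   key, value, hidden, roi = buffer.drop_select(input_tokens, None, "req-0")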