support deepseek quant & mix-parallel with graphmode (#585)
### What this PR does / why we need it?
1. Support DeepSeek with W8A8 quantization.
2. Support DeepSeek with mixed parallelism (multi-DP, EP+TP).
3. Support DeepSeek with graph mode.

---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
```diff
@@ -17,53 +17,66 @@
 # limitations under the License.
 #
 
-from typing import List, Tuple
+from typing import Any, List
 
 import torch
+from vllm.config import get_current_vllm_config
 from vllm.utils import is_pin_memory_available
 from vllm.worker.cache_engine import CacheEngine
 
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
-
 
 def allocate_kv_cache(
     self,
     num_blocks: int,
     device: str,
-) -> List[Tuple]:
+) -> List[Any]:
     """Allocates KV cache on the specified device."""
     kv_cache_shape = self.attn_backend.get_kv_cache_shape(
         num_blocks, self.block_size, self.num_kv_heads, self.head_size)
     pin_memory = is_pin_memory_available() if device == "cpu" else False
-    kv_cache: List[Tuple] = []
+    kv_cache: List[Any] = []
 
-    # Align entries so they are 256 byte aligned for better performance
-    # Primarily targets MLA as this typically only ends up having entries
-    # be 128 byte aligned.
-    alloc_shape = kv_cache_shape
+    additional_config = get_current_vllm_config().additional_config
+    if additional_config and additional_config.get("enable_graph_mode", False):
+        # Align entries so they are 256 byte aligned for better performance
+        # Primarily targets MLA as this typically only ends up having entries
+        # be 128 byte aligned.
+        alloc_shape = kv_cache_shape
 
-    for _ in range(self.num_attention_layers):
-        # null block in CpuGpuBlockAllocator requires at least that
-        # block to be zeroed-out.
-        # We zero-out everything for simplicity.
-        layer_kv_cache_nope = torch.zeros(
-            alloc_shape[:-1] +
-            (self.model_config.hf_text_config.kv_lora_rank, ),
-            dtype=self.dtype,
-            pin_memory=pin_memory,
-            device=device)
-        layer_kv_cache_pe = torch.zeros(
-            alloc_shape[:-1] +
-            (self.model_config.hf_text_config.qk_rope_head_dim, ),
-            dtype=self.dtype,
-            pin_memory=pin_memory,
-            device=device)
+        for _ in range(self.num_attention_layers):
+            # null block in CpuGpuBlockAllocator requires at least that
+            # block to be zeroed-out.
+            # We zero-out everything for simplicity.
+            layer_kv_cache_nope = torch.zeros(
+                alloc_shape[:-1] +
+                (self.model_config.hf_text_config.kv_lora_rank, ),
+                dtype=self.dtype,
+                pin_memory=pin_memory,
+                device=device)
+            layer_kv_cache_pe = torch.zeros(
+                alloc_shape[:-1] +
+                (self.model_config.hf_text_config.qk_rope_head_dim, ),
+                dtype=self.dtype,
+                pin_memory=pin_memory,
+                device=device)
 
-        # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
-        # when entry_shape is higher than 1D
-        kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
+            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
+            # when entry_shape is higher than 1D
+            kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
+    else:
+        for _ in range(self.num_attention_layers):
+            # null block in CpuGpuBlockAllocator requires at least that
+            # block to be zeroed-out.
+            # We zero-out everything for simplicity.
+            layer_kv_cache = torch.zeros(kv_cache_shape,
+                                         dtype=self.dtype,
+                                         pin_memory=pin_memory,
+                                         device=device)
+
+            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
+            # when entry_shape is higher than 1D
+            kv_cache.append(layer_kv_cache)
     return kv_cache
 
 
-if VLLM_ENABLE_GRAPH_MODE == '1':
-    CacheEngine._allocate_kv_cache = allocate_kv_cache
+CacheEngine._allocate_kv_cache = allocate_kv_cache
```
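The enablement mechanism changes with this patch: instead of checking the `VLLM_ENABLE_GRAPH_MODE` environment variable at import time, the patched allocator reads `enable_graph_mode` from `additional_config` at runtime, and `CacheEngine._allocate_kv_cache` is now replaced unconditionally. A minimal usage sketch, assuming the standard `vllm.LLM` entry point forwards `additional_config` to the engine config (the model path is illustrative):

```python
from vllm import LLM, SamplingParams

# Hypothetical checkpoint; any DeepSeek MLA model would exercise the
# nope/pe split-allocation branch added above.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    # Read by the patched allocator via
    # get_current_vllm_config().additional_config.get("enable_graph_mode", False)
    additional_config={"enable_graph_mode": True},
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```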
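For intuition about what the graph-mode branch allocates per layer, here is a small self-contained sketch of the shape arithmetic. The dimensions are assumptions for illustration (DeepSeek-V2-style MLA commonly uses `kv_lora_rank=512` and `qk_rope_head_dim=64`); in the real code they come from `attn_backend.get_kv_cache_shape()` and `model_config.hf_text_config`:

```python
import torch

# Illustrative sizes; real values come from the attention backend and the
# HF text config of the loaded model.
num_blocks, block_size, num_kv_heads = 4, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

# Backend-reported shape whose last dim gets replaced per MLA component.
kv_cache_shape = (num_blocks, block_size, num_kv_heads,
                  kv_lora_rank + qk_rope_head_dim)

# Graph mode stores the "nope" (compressed KV) and "pe" (rotary) parts as
# two separate zeroed tensors per layer, mirroring layer_kv_cache_nope and
# layer_kv_cache_pe in the diff.
layer_kv_cache_nope = torch.zeros(kv_cache_shape[:-1] + (kv_lora_rank, ))
layer_kv_cache_pe = torch.zeros(kv_cache_shape[:-1] + (qk_rope_head_dim, ))

print(layer_kv_cache_nope.shape)  # torch.Size([4, 128, 1, 512])
print(layer_kv_cache_pe.shape)    # torch.Size([4, 128, 1, 64])
```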