[ModelRunner][Refactor] Refactor kv cache tensor initialization logic (#3106)
### What this PR does / why we need it?
Refactor the KV cache tensor initialization logic.
1. Unify the KV cache tensor initialization logic of DeepSeek and normal models.
2. Split `initialize_kv_cache_tensors` into `_allocate_kv_cache_tensors` and `_reshape_kv_cache_tensors`, following the GPU model runner in vLLM (see the sketch after this list).
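
To make the new structure easier to follow, here is a minimal, self-contained sketch of the allocate-then-reshape flow from item 2 above. `FakeKVCacheTensor` and the two free functions are illustrative stand-ins, not the `NPUModelRunner` methods; the real implementation in the diff below additionally splits each layer's buffer into separate K and V tensors and 2M-aligns them for the prefill-disaggregation path.

```python
# Sketch only: simplified stand-ins for KVCacheConfig / NPUModelRunner internals.
from dataclasses import dataclass

import torch


@dataclass
class FakeKVCacheTensor:
    size: int          # bytes reserved for this layer's KV cache
    shared_by: list    # layer names sharing this buffer


def allocate_kv_cache_tensors(kv_cache_tensors):
    """Step 1: allocate a flat int8 byte buffer per kv_cache_tensor."""
    raw = {}
    for t in kv_cache_tensors:
        buf = torch.zeros(t.size, dtype=torch.int8)
        for layer_name in t.shared_by:
            raw[layer_name] = buf
    return raw


def reshape_kv_cache_tensors(raw, block_size, num_kv_heads, head_size, dtype):
    """Step 2: view each flat buffer as (num_blocks, block_size, heads, head_size)."""
    kv_caches = {}
    page_bytes = block_size * num_kv_heads * head_size * dtype.itemsize
    for layer_name, buf in raw.items():
        num_blocks = buf.numel() // page_bytes
        kv_caches[layer_name] = buf[:num_blocks * page_bytes].view(dtype).view(
            num_blocks, block_size, num_kv_heads, head_size)
    return kv_caches


# Usage: two layers, 2 MiB of cache each.
raw = allocate_kv_cache_tensors(
    [FakeKVCacheTensor(size=2 * 1024 * 1024, shared_by=["layer.0.attn"]),
     FakeKVCacheTensor(size=2 * 1024 * 1024, shared_by=["layer.1.attn"])])
caches = reshape_kv_cache_tensors(raw, block_size=128, num_kv_heads=8,
                                  head_size=16, dtype=torch.float16)
print({k: tuple(v.shape) for k, v in caches.items()})
```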
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with existing tests. The following scenarios were verified:
1. prefill disaggregation (see the alignment sketch below)
2. DeepSeek with aclgraph/eager mode
3. Qwen3-Next
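
The prefill-disaggregation scenario relies on the 2M address alignment that appears throughout the diff (the comments there note that llmdatadist's register_memory API needs 2M-aligned buffers). Below is a small sketch of that trick, mirroring the `_align_memory` helper visible in the diff; the round-up to the next aligned address and the buffer size are filled in here for illustration.

```python
import torch

ALIGNMENT = 2 * 1024 * 1024  # 2M, required by llmdatadist memory registration


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a view of `tensor` whose start address is `alignment`-aligned."""
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


# Over-allocate by `alignment` bytes, align, then slice back to the exact size.
size = 4096
buf = torch.zeros(size + ALIGNMENT, dtype=torch.int8)
aligned = align_memory(buf, ALIGNMENT)[:size]
assert aligned.data_ptr() % ALIGNMENT == 0 and aligned.numel() == size
```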
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
```diff
@@ -3138,15 +3138,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ])

         self.may_reinitialize_input_batch(kv_cache_config)
-
-        if self.use_sparse:
-            kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
-                kv_cache_config)
-        elif self.model_config.is_deepseek_mla:
-            kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
-                kv_cache_config)
-        else:
-            kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
@@ -3158,197 +3150,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         offset = (aligned_addr - data_ptr) // tensor.element_size()
         return tensor[int(offset):]

-    def initialize_kv_cache_tensors_deepseek_sfa(
-            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
-        kv_cache_sizes = {}
-        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            assert len(kv_cache_tensor.shared_by) == 1, (
-                "KV cache tensor shared by multiple layers is not supported in "
-                "NPU.")
-            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
-
-        kv_caches: Dict[str, torch.Tensor] = {}
-        for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
-            attn_backend = group.backend
-            for layer_name in group.layer_names:
-                if layer_name in self.runner_only_attn_layers:
-                    continue
-                tensor_size = kv_cache_sizes[layer_name]
-                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
-                if self.vllm_config.additional_config.get(
-                        "kv_cache_dtype", None) == 'int8':
-                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                elif hasattr(
-                        attn_backend, "get_supported_block_size"
-                ) and not self.model_config.is_deepseek_mla and not self.use_sparse:
-                    block_size = attn_backend.get_supported_block_size()[0]
-                    block_size_chunk = kv_cache_spec.block_size // block_size
-                    kv_cache_shape = attn_backend.get_kv_cache_shape(
-                        num_blocks * block_size_chunk, block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                else:
-                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                dtype = kv_cache_spec.dtype
-
-                alignment = 2 * 1024 * 1024
-                num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
-                rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-                nope_dim = head_size - rope_dim
-                nope_cache_shape = (num_blocks, block_size, num_kv_heads,
-                                    nope_dim)
-                rope_cache_shape = (num_blocks, block_size, num_kv_heads,
-                                    rope_dim)
-                #### k cache
-                # TODO(zzzzwwjj): wait transformers add these params
-                k_cache_shape = (num_blocks, block_size, 1, 128)
-                if self.vllm_config.kv_transfer_config is None:
-                    # For no disaggregate pd scenario, allocate kv cache in normal way
-                    rope_cache = torch.zeros(rope_cache_shape,
-                                             dtype=dtype,
-                                             device=self.device)
-                    nope_cache = torch.zeros(nope_cache_shape,
-                                             dtype=dtype,
-                                             device=self.device)
-                    rope_cache = self._convert_torch_format(rope_cache)
-                    nope_cache = self._convert_torch_format(nope_cache)
-
-                    #### k cache
-                    k_cache = torch.zeros(k_cache_shape,
-                                          dtype=dtype,
-                                          device=self.device)
-                    k_cache = self._convert_torch_format(k_cache)
-                else:
-
-                    # In order to transfer kv cache through the register_memory api from llmdatadist, the memory
-                    # address should be aligned by 2M. In most cases, torch_npu can allocate 2M aligned memory, but
-                    # we found there are also some exceptions during test, so we manually align the memory here; this part
-                    # of code may consume 2M * 2 * elem_size memory every layer.
-                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
-                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
-                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
-                    rope_allocate_shape_alignment = rope_allocate_shape + alignment
-
-                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
-                                             dtype=dtype,
-                                             device=self.device)
-                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
-                                             dtype=dtype,
-                                             device=self.device)
-                    #### k cache
-                    # TODO(zzzzwwjj): wait transformers add these params
-                    k_allocate_shape = num_blocks * block_size * 1 * 128
-                    k_allocate_shape_alignment = k_allocate_shape + alignment
-                    k_cache = torch.zeros(k_allocate_shape_alignment,
-                                          dtype=dtype,
-                                          device=self.device)
-
-                    nope_cache = self._align_memory(
-                        nope_cache,
-                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
-                    rope_cache = self._align_memory(
-                        rope_cache,
-                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
-                    k_cache = self._align_memory(
-                        k_cache,
-                        alignment)[:k_allocate_shape].view(k_cache_shape)
-
-                kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
-        bind_kv_cache(kv_caches,
-                      self.compilation_config.static_forward_context,
-                      self.kv_caches)
-
-        return kv_caches
-
-    def initialize_kv_cache_tensors_deepseek_mla(
-            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
-        kv_cache_sizes = {}
-        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            assert len(kv_cache_tensor.shared_by) == 1, (
-                "KV cache tensor shared by multiple layers is not supported in "
-                "NPU.")
-            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
-
-        kv_caches: Dict[str, torch.Tensor] = {}
-        for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
-            attn_backend = group.backend
-            for layer_name in group.layer_names:
-                if layer_name in self.runner_only_attn_layers:
-                    continue
-                tensor_size = kv_cache_sizes[layer_name]
-                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
-                if self.vllm_config.additional_config.get(
-                        "kv_cache_dtype", None) == 'int8':
-                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                elif hasattr(attn_backend, "get_supported_block_size"
-                             ) and not self.model_config.is_deepseek_mla:
-                    block_size = attn_backend.get_supported_block_size()[0]
-                    block_size_chunk = kv_cache_spec.block_size // block_size
-                    kv_cache_shape = attn_backend.get_kv_cache_shape(
-                        num_blocks * block_size_chunk, block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                else:
-                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                dtype = kv_cache_spec.dtype
-
-                alignment = 2 * 1024 * 1024
-                num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
-                rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-                nope_dim = head_size - rope_dim
-                nope_cache_shape = (num_blocks, block_size, num_kv_heads,
-                                    nope_dim)
-                rope_cache_shape = (num_blocks, block_size, num_kv_heads,
-                                    rope_dim)
-                if self.vllm_config.kv_transfer_config is None:
-                    # For no disaggregate pd scenario, allocate kv cache in normal way
-                    rope_cache = torch.zeros(rope_cache_shape,
-                                             dtype=dtype,
-                                             device=self.device)
-                    nope_cache = torch.zeros(nope_cache_shape,
-                                             dtype=dtype,
-                                             device=self.device)
-                    rope_cache = self._convert_torch_format(rope_cache)
-                    nope_cache = self._convert_torch_format(nope_cache)
-                else:
-
-                    # In order to transfer kv cache through the register_memory api from llmdatadist, the memory
-                    # address should be aligned by 2M. In most cases, torch_npu can allocate 2M aligned memory, but
-                    # we found there are also some exceptions during test, so we manually align the memory here; this part
-                    # of code may consume 2M * 2 * elem_size memory every layer.
-                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
-                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
-                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
-                    rope_allocate_shape_alignment = rope_allocate_shape + alignment
-
-                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
-                                             dtype=dtype,
-                                             device=self.device)
-                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
-                                             dtype=dtype,
-                                             device=self.device)
-                    nope_cache = self._align_memory(
-                        nope_cache,
-                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
-                    rope_cache = self._align_memory(
-                        rope_cache,
-                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
-                kv_caches[layer_name] = (nope_cache, rope_cache)
-
-        bind_kv_cache(kv_caches,
-                      self.compilation_config.static_forward_context,
-                      self.kv_caches)
-
-        return kv_caches
-
     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
@@ -3360,6 +3161,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
         """
+        # Initialize the memory buffer for KV cache
+        kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
+        # Change the memory buffer to the desired shape
+        kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
+                                                   kv_cache_raw_tensors)
+
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+        return kv_caches
+
+    def _allocate_kv_cache_tensors(
+            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+        """
+        Initializes the KV cache buffer with the correct size. The buffer needs
+        to be reshaped to the desired shape before being used by the models.
+
+        NOTE: To support prefill disaggregation, we need to split the kvcache tensor into
+        k_cache and v_cache, and the addresses of both are aligned by 2M.
+
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            dict[str, torch.Tensor]: A map between layer names to their
+                corresponding memory buffer for KV cache.
+            dict[str, tuple(torch.Tensor, torch.Tensor)]: A map between layer names
+                to their corresponding memory buffer for K cache and V cache.
+        """
         # init kv cache tensors
         kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                               Optional[torch.Tensor]]] = {}
@@ -3383,39 +3212,88 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                              device=self.device)
                         tensor = self._align_memory(
                             tensor, alignment)[:kv_cache_tensor.size]

                     for layer_name_inner in kv_cache_tensor.shared_by:
                         # shared the kvcache between the linear_attn specs in the same group
                         # shared the kvcache between the self_attn specs in the same group
                         if "linear_attn" in layer_name_inner:
                             kv_cache_raw_tensors[layer_name_inner] = tensor
                 elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
                 ):
+                    # NOTE: We need to init k cache tensor (nope cache tensor in mla) and
+                    # v cache tensor (rope cache tensor in mla) separately to support llmdatadist,
+                    # as it only supports the 0-dim of kv_cache being `num_blocks`.
+                    # For deepseek mla, we need to split the cache tensor according to the nope head dim
+                    # and rope head dim.
+                    if self.model_config.is_deepseek_mla:
+                        head_size = self.model_config.hf_text_config.qk_rope_head_dim + \
+                            self.model_config.hf_text_config.kv_lora_rank
+
+                    dsa_k_cache_factor = None
+                    dsa_k_cache_size = None
+                    if not self.model_config.is_deepseek_mla:
+                        # for non-mla model, use FullAttentionSpec
+                        k_tensor_split_factor = 2
+                        v_tensor_split_factor = 2
+                    elif self.use_sparse:
+                        # for deepseek v3.2, DSA uses FullAttentionSpec
+                        # FullAttentionSpec allocates 2 * mla page size bytes,
+                        # and we use half of that for k cache in DSA
+                        dsa_k_cache_factor = 2
+                        k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
+                        v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
+                        dsa_k_cache_size = int(kv_cache_tensor.size //
+                                               dsa_k_cache_factor)
+                    else:
+                        # for other deepseek models, use MLAAttentionSpec
+                        k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
+                        v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim
+
+                    k_tensor_size = int(kv_cache_tensor.size //
+                                        k_tensor_split_factor)
+                    v_tensor_size = int(kv_cache_tensor.size //
+                                        v_tensor_split_factor)
+
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
-                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                        k_tensor = torch.zeros(k_tensor_size,
                                                dtype=torch.int8,
                                                device=self.device)
-                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                        v_tensor = torch.zeros(v_tensor_size,
                                                dtype=torch.int8,
                                                device=self.device)
+                        #### k cache: for deepseek sparse attention
+                        if dsa_k_cache_factor is not None:
+                            dsa_k_cache_tensor = torch.zeros(
+                                dsa_k_cache_size,
+                                dtype=torch.int8,
+                                device=self.device)
                     else:
-                        cache_size = kv_cache_tensor.size // 2
-                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
-                        k_tensor = torch.zeros(cache_size_aligned,
+                        k_tensor = torch.zeros(k_tensor_size + alignment,
                                                dtype=torch.int8,
                                                device=self.device)
-                        v_tensor = torch.zeros(cache_size_aligned,
+                        v_tensor = torch.zeros(v_tensor_size + alignment,
                                                dtype=torch.int8,
                                                device=self.device)
-                        k_tensor = self._align_memory(k_tensor,
-                                                      alignment)[:cache_size]
-                        v_tensor = self._align_memory(v_tensor,
-                                                      alignment)[:cache_size]
+                        k_tensor = self._align_memory(
+                            k_tensor, alignment)[:k_tensor_size]
+                        v_tensor = self._align_memory(
+                            v_tensor, alignment)[:v_tensor_size]
+                        #### k cache: for deepseek sparse attention
+                        if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
+                            dsa_k_cache_tensor = torch.zeros(
+                                dsa_k_cache_size + alignment,
+                                dtype=torch.int8,
+                                device=self.device)
+                            dsa_k_cache_tensor = self._align_memory(
+                                dsa_k_cache_tensor,
+                                alignment)[:dsa_k_cache_size]
+
                     for layer_name_inner in kv_cache_tensor.shared_by:
                         # shared the kvcache between the self_attn specs in the same group
                         if ("attn" in layer_name_inner
                                 and "linear_attn" not in layer_name_inner):
-                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
-                                                                      v_tensor)
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) if \
+                                not self.use_sparse else (k_tensor, v_tensor, dsa_k_cache_tensor)
+
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
@@ -3426,6 +3304,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         assert layer_names == set(kv_cache_raw_tensors.keys(
         )), "Some layers are not correctly initialized"

+        return kv_cache_raw_tensors
+
+    def _reshape_kv_cache_tensors(
+        self,
+        kv_cache_config: KVCacheConfig,
+        kv_cache_raw_tensors: dict[str, torch.Tensor],
+    ) -> dict[str, torch.Tensor]:
+        """
+        Reshape the KV cache tensors to the desired shape and dtype.
+
+        Args:
+            kv_cache_config: The KV cache config
+            kv_cache_raw_tensors: The KV cache buffer of each layer, with
+                correct size but uninitialized shape.
+        Returns:
+            Dict[str, torch.Tensor]: A map between layer names to their
+                corresponding memory buffer for KV cache.
+        """
         kv_caches: Dict[str, torch.Tensor] = {}
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
@@ -3437,14 +3333,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
                 # encounter OOM issue
                 if isinstance(kv_cache_spec, FullAttentionSpec):
-                    raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
-                        layer_name]
+                    raw_dsa_k_tensor = None
+                    if self.use_sparse:
+                        raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[  # type: ignore
+                            layer_name]
+                        assert raw_dsa_k_tensor is not None
+                        sum_page_size_bytes = raw_k_tensor.numel(
+                        ) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
+                    else:
+                        raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
+                            layer_name]
+                        sum_page_size_bytes = raw_k_tensor.numel(
+                        ) + raw_v_tensor.numel()
                     assert raw_k_tensor is not None
                     assert raw_v_tensor is not None
-                    assert (raw_k_tensor.numel() + raw_v_tensor.numel()
-                            ) % kv_cache_spec.page_size_bytes == 0
-                    num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
-                                  ) // kv_cache_spec.page_size_bytes
+                    assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes

                     # `num_blocks` is the number of blocks the model runner can use.
                     # `kv_cache_config.num_blocks` is the number of blocks that
@@ -3476,11 +3380,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
+                    if not self.model_config.is_deepseek_mla:
+                        k_shape = kv_cache_shape[1:]
+                        v_shape = k_shape
+                    else:
+                        # k_cache: nope_cache  v_cache: rope_cache
+                        mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
+                        k_shape = [
+                            mla_num_blocks, mla_block_size, num_kv_heads,
+                            self.model_config.hf_text_config.kv_lora_rank
+                        ]
+                        v_shape = [
+                            mla_num_blocks, mla_block_size, num_kv_heads,
+                            self.model_config.hf_text_config.qk_rope_head_dim
+                        ]
+                    k_cache = raw_k_tensor.view(dtype).view(k_shape)
                     k_cache = self._convert_torch_format(k_cache)
-                    v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
+                    v_cache = raw_v_tensor.view(dtype).view(v_shape)
                     v_cache = self._convert_torch_format(v_cache)
-                    kv_caches[layer_name] = (k_cache, v_cache)
+                    if self.use_sparse and raw_dsa_k_tensor is not None:
+                        dsa_k_cache_shape = (num_blocks,
+                                             kv_cache_spec.block_size, 1, 128)
+                        dsa_k_cache_size = (
+                            num_blocks
+                        ) * kv_cache_spec.block_size * 128 * dtype.itemsize
+                        dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(
+                            dtype).view(dsa_k_cache_shape)
+                        kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
+                    else:
+                        kv_caches[layer_name] = (k_cache, v_cache)
                 elif isinstance(kv_cache_spec, MambaSpec):
                     raw_tensor = kv_cache_raw_tensors[layer_name]
                     assert raw_tensor is not None
@@ -3521,10 +3449,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 else:
                     raise ValueError("Unknown KV cache spec type.")

-        bind_kv_cache(kv_caches,
-                      self.compilation_config.static_forward_context,
-                      self.kv_caches)
-
         return kv_caches

     def may_reinitialize_input_batch(self,
```