[DeepSeek-V3.2] Include indexer kv cache when estimating kv cache size (#11309)
This commit is contained in:
@@ -1177,7 +1177,9 @@ class MLATokenToKVPool(KVCache):
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self._finalize_allocation_log(size)
|
||||
if not use_nsa:
|
||||
# NSA will allocate indexer KV cache later and then log the total size
|
||||
self._finalize_allocation_log(size)
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
assert hasattr(self, "kv_buffer")
|
||||
@@ -1298,6 +1300,9 @@ class MLATokenToKVPool(KVCache):
|
||||
|
||||
|
||||
class NSATokenToKVPool(MLATokenToKVPool):
|
||||
quant_block_size = 128
|
||||
index_k_with_scale_buffer_dtype = torch.uint8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@@ -1331,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
||||
# num head == 1 and head dim == 128 for index_k in NSA
|
||||
assert index_head_dim == 128
|
||||
|
||||
self.quant_block_size = 128
|
||||
|
||||
assert self.page_size == 64
|
||||
self.index_k_with_scale_buffer = [
|
||||
torch.zeros(
|
||||
@@ -1347,11 +1350,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
||||
self.page_size
|
||||
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
dtype=self.index_k_with_scale_buffer_dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self._finalize_allocation_log(size)
|
||||
|
||||
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
||||
if self.layer_transfer_counter is not None:
|
||||
@@ -1393,6 +1397,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
||||
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
||||
)
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
kv_size_bytes = super().get_kv_size_bytes()
|
||||
for index_k_cache in self.index_k_with_scale_buffer:
|
||||
kv_size_bytes += get_tensor_size_bytes(index_k_cache)
|
||||
return kv_size_bytes
|
||||
|
||||
|
||||
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
def __init__(
|
||||
|
||||
@@ -1280,6 +1280,17 @@ class ModelRunner:
|
||||
* num_layers
|
||||
* torch._utils._element_size(self.kv_cache_dtype)
|
||||
)
|
||||
# Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
|
||||
if is_deepseek_nsa(self.model_config.hf_config):
|
||||
index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
|
||||
indexer_size_per_token = (
|
||||
index_head_dim
|
||||
+ index_head_dim // NSATokenToKVPool.quant_block_size * 4
|
||||
)
|
||||
element_size = torch._utils._element_size(
|
||||
NSATokenToKVPool.index_k_with_scale_buffer_dtype
|
||||
)
|
||||
cell_size += indexer_size_per_token * num_layers * element_size
|
||||
else:
|
||||
cell_size = (
|
||||
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||
|
||||
@@ -863,9 +863,6 @@ class ServerArgs:
|
||||
self.page_size = 64
|
||||
logger.warning("Setting page size to 64 for DeepSeek NSA.")
|
||||
|
||||
self.mem_fraction_static = 0.8
|
||||
logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
|
||||
|
||||
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
|
||||
import torch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user