[DeepSeek-V3.2] Include indexer kv cache when estimating kv cache size (#11309)
This commit is contained in:
@@ -1177,7 +1177,9 @@ class MLATokenToKVPool(KVCache):
|
|||||||
dtype=torch.uint64,
|
dtype=torch.uint64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self._finalize_allocation_log(size)
|
if not use_nsa:
|
||||||
|
# NSA will allocate indexer KV cache later and then log the total size
|
||||||
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
assert hasattr(self, "kv_buffer")
|
assert hasattr(self, "kv_buffer")
|
||||||
@@ -1298,6 +1300,9 @@ class MLATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
|
|
||||||
class NSATokenToKVPool(MLATokenToKVPool):
|
class NSATokenToKVPool(MLATokenToKVPool):
|
||||||
|
quant_block_size = 128
|
||||||
|
index_k_with_scale_buffer_dtype = torch.uint8
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
@@ -1331,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
|||||||
# num head == 1 and head dim == 128 for index_k in NSA
|
# num head == 1 and head dim == 128 for index_k in NSA
|
||||||
assert index_head_dim == 128
|
assert index_head_dim == 128
|
||||||
|
|
||||||
self.quant_block_size = 128
|
|
||||||
|
|
||||||
assert self.page_size == 64
|
assert self.page_size == 64
|
||||||
self.index_k_with_scale_buffer = [
|
self.index_k_with_scale_buffer = [
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
@@ -1347,11 +1350,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
|||||||
self.page_size
|
self.page_size
|
||||||
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
|
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
|
||||||
),
|
),
|
||||||
dtype=torch.uint8,
|
dtype=self.index_k_with_scale_buffer_dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
@@ -1393,6 +1397,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
|
|||||||
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_kv_size_bytes(self):
|
||||||
|
kv_size_bytes = super().get_kv_size_bytes()
|
||||||
|
for index_k_cache in self.index_k_with_scale_buffer:
|
||||||
|
kv_size_bytes += get_tensor_size_bytes(index_k_cache)
|
||||||
|
return kv_size_bytes
|
||||||
|
|
||||||
|
|
||||||
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1280,6 +1280,17 @@ class ModelRunner:
|
|||||||
* num_layers
|
* num_layers
|
||||||
* torch._utils._element_size(self.kv_cache_dtype)
|
* torch._utils._element_size(self.kv_cache_dtype)
|
||||||
)
|
)
|
||||||
|
# Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
|
||||||
|
if is_deepseek_nsa(self.model_config.hf_config):
|
||||||
|
index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
|
||||||
|
indexer_size_per_token = (
|
||||||
|
index_head_dim
|
||||||
|
+ index_head_dim // NSATokenToKVPool.quant_block_size * 4
|
||||||
|
)
|
||||||
|
element_size = torch._utils._element_size(
|
||||||
|
NSATokenToKVPool.index_k_with_scale_buffer_dtype
|
||||||
|
)
|
||||||
|
cell_size += indexer_size_per_token * num_layers * element_size
|
||||||
else:
|
else:
|
||||||
cell_size = (
|
cell_size = (
|
||||||
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||||
|
|||||||
@@ -863,9 +863,6 @@ class ServerArgs:
|
|||||||
self.page_size = 64
|
self.page_size = 64
|
||||||
logger.warning("Setting page size to 64 for DeepSeek NSA.")
|
logger.warning("Setting page size to 64 for DeepSeek NSA.")
|
||||||
|
|
||||||
self.mem_fraction_static = 0.8
|
|
||||||
logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
|
|
||||||
|
|
||||||
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
|
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user