[Fix] Support qwen3-next MTP+DP (#10392)

This commit is contained in:
Binyao Jiang
2025-09-13 02:45:04 -07:00
committed by GitHub
parent 297d374510
commit 9752861002
4 changed files with 29 additions and 18 deletions

View File

@@ -52,6 +52,10 @@ if _is_npu:
import torch_npu
def get_tensor_size_bytes(t: torch.Tensor):
return np.prod(t.shape) * t.dtype.itemsize
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
@@ -158,16 +162,23 @@ class MambaPool:
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = (conv_state, temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.get_mamba_size() / GB
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
)
def get_mamba_params_all_layers(self):
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
@@ -176,10 +187,7 @@ class MambaPool:
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
def get_mamba_size(self):
return (
np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
+ np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
)
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
def available_size(self):
return len(self.free_slots)
@@ -492,10 +500,10 @@ class MHATokenToKVPool(KVCache):
assert hasattr(self, "v_buffer")
k_size_bytes = 0
for k_cache in self.k_buffer:
k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
k_size_bytes += get_tensor_size_bytes(k_cache)
v_size_bytes = 0
for v_cache in self.v_buffer:
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
v_size_bytes += get_tensor_size_bytes(v_cache)
return k_size_bytes, v_size_bytes
# for disagg
@@ -1077,7 +1085,7 @@ class MLATokenToKVPool(KVCache):
assert hasattr(self, "kv_buffer")
kv_size_bytes = 0
for kv_cache in self.kv_buffer:
kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
kv_size_bytes += get_tensor_size_bytes(kv_cache)
return kv_size_bytes
# for disagg
@@ -1240,9 +1248,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
assert hasattr(self, "v_buffer")
kv_size_bytes = 0
for k_cache in self.k_buffer:
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
kv_size_bytes += get_tensor_size_bytes(k_cache)
for v_cache in self.v_buffer:
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
kv_size_bytes += get_tensor_size_bytes(v_cache)
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):