Qwen3-Next support (#10233)

Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Author: Yi Zhang
Date: 2025-09-11 19:11:49 +08:00
Committed by: GitHub
Parent: bfe01a5eef
Commit: 30c6e1f569
19 changed files with 3224 additions and 8 deletions


@@ -102,6 +102,204 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


class MambaPool:
    def __init__(
        self,
        size: int,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        num_mamba_layers: int,
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        conv_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + conv_state_shape,
            dtype=conv_dtype,
            device=device,
        )
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,
            device=device,
        )
        if speculative_num_draft_tokens is not None:
            mixed_qkv_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    conv_state_shape[0],
                ),
                dtype=conv_dtype,
                device="cuda",
            )
            # Cache intermediate SSM states per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
            intermediate_ssm_state_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    temporal_state_shape[0],
                    temporal_state_shape[1],
                    temporal_state_shape[2],
                ),
                dtype=ssm_dtype,
                device="cuda",
            )
            self.mamba_cache = (
                conv_state,
                temporal_state,
                mixed_qkv_cache,
                intermediate_ssm_state_cache,
            )
        else:
            self.mamba_cache = (conv_state, temporal_state)
        self.size = size
        self.free_slots = list(range(size))
        self.mem_usage = self.get_mamba_size() / GB
        logger.info(
            f"Mamba Cache is allocated. "
            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
        )
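
    # Layout of self.mamba_cache, indexed as mamba_cache[i][layer][slot]
    # (one extra slot beyond `size` is allocated along the request dimension):
    #   [0] conv_state:     (num_mamba_layers, size + 1, *conv_state_shape), conv_dtype
    #   [1] temporal_state: (num_mamba_layers, size + 1, *temporal_state_shape), ssm_dtype
    # When speculative decoding is enabled (speculative_num_draft_tokens is not None),
    # two per-draft-token buffers are appended:
    #   [2] mixed_qkv_cache:              (..., speculative_num_draft_tokens, conv_state_shape[0])
    #   [3] intermediate_ssm_state_cache: (..., speculative_num_draft_tokens, HV, K, V)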

    def get_mamba_params_all_layers(self):
        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]

    def get_mamba_params(self, layer_id: int):
        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]

    def get_mamba_size(self):
        return (
            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
        )

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, (int,)):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)
        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0

    def clear(self):
        self.free_slots = list(range(self.size))
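
# A minimal usage sketch of MambaPool (illustration only, not part of this diff).
# The constructor arguments below are hypothetical placeholders, chosen only to show
# how per-request Mamba state slots are allocated and freed.
def _mamba_pool_usage_sketch():
    pool = MambaPool(
        size=64,                              # number of request slots
        conv_dtype=torch.bfloat16,
        ssm_dtype=torch.float32,
        num_mamba_layers=24,
        conv_state_shape=(4096, 4),           # hypothetical conv state shape
        temporal_state_shape=(32, 128, 128),  # hypothetical SSM state shape
        device="cuda",
    )
    slots = pool.alloc(2)          # e.g. [0, 1]; None if the pool is exhausted
    conv, ssm = pool.get_mamba_params(layer_id=0)
    pool.free(slots)               # zeroes the conv/ssm state of the freed slots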


class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        mamba_layers: List[int],
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        speculative_num_draft_tokens: int,
    ):
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )
        self.mamba_pool = MambaPool(
            size,
            conv_dtype,
            ssm_dtype,
            len(mamba_layers),
            conv_state_shape,
            temporal_state_shape,
            device,
            speculative_num_draft_tokens,
        )
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
        self.device = device
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
            size, dtype=torch.int32, device=self.device
        )
        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
    # For a chunked prefill request, we do not need to allocate a new mamba cache;
    # the cache already allocated for that request is reused.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        select_index = super().alloc(need_size)
        if select_index is None:
            return None
        mamba_index = []
        for req in reqs:
            rid = req.rid
            if rid in self.rid_to_mamba_index_mapping:
                mid = self.rid_to_mamba_index_mapping[rid]
            elif (mid := self.mamba_pool.alloc(1)) is not None:
                mid = mid[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
            mamba_index.append(mid)
        assert len(select_index) == len(
            mamba_index
        ), "Not enough space for mamba cache, try to increase --max-mamba-cache-size."
        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index
    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        return self.req_index_to_mamba_index_mapping[req_indices]

    def get_mamba_params(self, layer_id: int):
        assert layer_id in self.mamba_map
        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])

    def get_mamba_params_all_layers(self):
        return self.mamba_pool.get_mamba_params_all_layers()
    # For chunked prefill, we cannot free the mamba cache here because the remaining
    # chunks of the request still need it.
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping[mid]
                self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        super().clear()
        self.mamba_pool.clear()
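
# A minimal sketch (illustration only, not part of this diff) of how the rid-to-mamba-slot
# mapping lets chunked prefill reuse the same mamba cache across chunks. `_FakeReq` is a
# hypothetical stand-in for the scheduler's Req object, carrying only the `rid` attribute
# that HybridReqToTokenPool.alloc() reads.
def _hybrid_req_to_token_pool_sketch(pool: HybridReqToTokenPool):
    class _FakeReq:
        def __init__(self, rid: str):
            self.rid = rid

    req = _FakeReq("request-0")
    # First chunk: allocates both a req-to-token slot and a mamba slot for this rid.
    idx1 = pool.alloc(1, reqs=[req])
    # Free the req slot but keep the mamba cache so the next chunk can reuse it.
    pool.free(idx1, free_mamba_cache=False)
    # Next chunk of the same request: a new req slot maps to the same mamba slot.
    idx2 = pool.alloc(1, reqs=[req])
    mamba_slot = pool.get_mamba_indices(torch.tensor(idx2, device=pool.device))
    # Finishing the request releases the mamba slot and the rid mappings as well.
    pool.free(idx2, free_mamba_cache=True)
    return mamba_slot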


class KVCache(abc.ABC):
    @abc.abstractmethod
    def __init__(
@@ -441,6 +639,88 @@ class MHATokenToKVPool(KVCache):
        )


class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = 1
        # TODO: use MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
        self.full_kv_pool = MHATokenToKVPool(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        self.full_attention_layer_id_mapping = {
            id: i for i, id in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        if layer_id not in self.full_attention_layer_id_mapping:
            raise ValueError(
                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
            )
        return self.full_attention_layer_id_mapping[layer_id]

    def get_key_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_key_buffer(layer_id)

    def get_value_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=layer_id,
        )
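
# A minimal sketch (illustration only, not part of this diff) of how HybridLinearKVPool
# remaps model layer ids: only full-attention layers own KV storage, and their ids are
# translated to a compact index into the underlying MHATokenToKVPool. The layer ids
# below are hypothetical.
def _hybrid_linear_kv_pool_sketch(pool: HybridLinearKVPool):
    # Suppose the pool was built with full_attention_layer_ids=[3, 7, 11]; then model
    # layer 7 maps to internal index 1 of the full-attention KV pool.
    k7 = pool.get_key_buffer(7)
    v7 = pool.get_value_buffer(7)
    # Linear-attention layers (e.g. layer 4) have no KV entries here: looking one up
    # raises a ValueError from _transfer_full_attention_id.
    try:
        pool.get_kv_buffer(4)
    except ValueError:
        pass
    return k7, v7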


class SWAKVPool(KVCache):
    """KV cache with separate pools for full and SWA attention layers."""