Qwen3-Next support (#10233)
Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
@@ -102,6 +102,204 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


class MambaPool:
    def __init__(
        self,
        size: int,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        num_mamba_layers: int,
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        conv_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + conv_state_shape,
            dtype=conv_dtype,
            device=device,
        )
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,
            device=device,
        )
        if speculative_num_draft_tokens is not None:
            # Cache the mixed qkv inputs per draft token during target verify
            mixed_qkv_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    conv_state_shape[0],
                ),
                dtype=conv_dtype,
                device="cuda",
            )
            # Cache intermediate SSM states per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
            intermediate_ssm_state_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    temporal_state_shape[0],
                    temporal_state_shape[1],
                    temporal_state_shape[2],
                ),
                dtype=ssm_dtype,
                device="cuda",
            )
            self.mamba_cache = (
                conv_state,
                temporal_state,
                mixed_qkv_cache,
                intermediate_ssm_state_cache,
            )
        else:
            self.mamba_cache = (conv_state, temporal_state)
        self.size = size
        self.free_slots = list(range(size))
        self.mem_usage = self.get_mamba_size() / GB
        logger.info(
            f"Mamba Cache is allocated. "
            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB"
        )

    def get_mamba_params_all_layers(self):
        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]

    def get_mamba_params(self, layer_id: int):
        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]

    def get_mamba_size(self):
        return (
            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
        )

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)
        # Zero out the freed conv and ssm states so a new request starts clean.
        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0

    def clear(self):
        self.free_slots = list(range(self.size))

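A minimal usage sketch of the `MambaPool` slot lifecycle (not part of the diff above), assuming the class is importable from this module; the shapes, dtypes, and sizes below are illustrative placeholders:

```python
import torch

# Toy pool: 4 request slots, 2 mamba layers, made-up state shapes.
pool = MambaPool(
    size=4,
    conv_dtype=torch.bfloat16,
    ssm_dtype=torch.float32,
    num_mamba_layers=2,
    conv_state_shape=(8, 3),         # illustrative (conv_dim, kernel_size - 1)
    temporal_state_shape=(2, 4, 4),  # illustrative (num_heads, head_dim, state_dim)
    device="cpu",
)

slots = pool.alloc(2)                       # e.g. [0, 1]; None if the pool is exhausted
conv, ssm = pool.get_mamba_params(layer_id=0)
assert conv.shape[0] == pool.size + 1       # the slot dimension holds size + 1 entries

pool.free(slots)                            # zeroes the freed states, returns the slots
assert pool.available_size() == 4
```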
class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        mamba_layers: List[int],
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        speculative_num_draft_tokens: int,
    ):
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self.mamba_pool = MambaPool(
            size,
            conv_dtype,
            ssm_dtype,
            len(mamba_layers),
            conv_state_shape,
            temporal_state_shape,
            device,
            speculative_num_draft_tokens,
        )
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}

        self.device = device
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
            size, dtype=torch.int32, device=self.device
        )

        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}

    # For a chunked-prefill request we do not allocate a new mamba cache slot;
    # the slot allocated for an earlier chunk of the same rid is reused.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        select_index = super().alloc(need_size)
        if select_index is None:
            return None

        mamba_index = []
        for req in reqs:
            rid = req.rid
            if rid in self.rid_to_mamba_index_mapping:
                mid = self.rid_to_mamba_index_mapping[rid]
            elif (mid := self.mamba_pool.alloc(1)) is not None:
                mid = mid[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
            mamba_index.append(mid)
        assert len(select_index) == len(
            mamba_index
        ), "Not enough space for mamba cache, try to increase --max-mamba-cache-size."
        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        return self.req_index_to_mamba_index_mapping[req_indices]

    def get_mamba_params(self, layer_id: int):
        assert layer_id in self.mamba_map
        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])

    def get_mamba_params_all_layers(self):
        return self.mamba_pool.get_mamba_params_all_layers()

    # For chunked prefill we cannot free the mamba cache here: the remaining
    # chunks of the request still need it, so callers pass free_mamba_cache=False.
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping[mid]
                self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        super().clear()
        self.mamba_pool.clear()

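A rough sketch (not from this PR) of how `HybridReqToTokenPool` ties request slots to mamba slots: each rid gets one mamba slot, a later chunk of the same rid reuses it, and freeing request slots with `free_mamba_cache=False` keeps the mamba state alive for the remaining chunks. Here `hybrid_pool` stands for an already-constructed instance, and the lightweight request objects are stand-ins for `Req`:

```python
import torch
from types import SimpleNamespace

# Stand-ins for Req; only the .rid attribute is used by alloc().
req_a, req_b = SimpleNamespace(rid="req-a"), SimpleNamespace(rid="req-b")

req_slots = hybrid_pool.alloc(2, reqs=[req_a, req_b])   # one mamba slot assigned per rid
mamba_ids = hybrid_pool.get_mamba_indices(
    torch.tensor(req_slots, dtype=torch.int64, device=hybrid_pool.device)
)

# Free only the request-to-token slots; the mamba states stay allocated so the
# next chunk of the same requests can keep extending them.
hybrid_pool.free(req_slots, free_mamba_cache=False)

# A later chunk of "req-a" maps back to the same mamba slot via its rid.
next_chunk_slot = hybrid_pool.alloc(1, reqs=[req_a])
```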
class KVCache(abc.ABC):
    @abc.abstractmethod
    def __init__(
@@ -441,6 +639,88 @@ class MHATokenToKVPool(KVCache):
        )

class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = 1
        # TODO: use MHATransposedTokenToKVPool when enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
        self.full_kv_pool = MHATokenToKVPool(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        self.full_attention_layer_id_mapping = {
            id: i for i, id in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        if layer_id not in self.full_attention_layer_id_mapping:
            raise ValueError(
                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
            )
        return self.full_attention_layer_id_mapping[layer_id]

    def get_key_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_key_buffer(layer_id)

    def get_value_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=layer_id,
        )

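A toy illustration (not part of the diff) of the layer-id remapping behind `_transfer_full_attention_id`: in a hybrid model only a subset of layers use full attention, so their model-level layer ids are compacted into dense indices into the underlying `MHATokenToKVPool`. The layer ids below are made up:

```python
# Suppose only model layers 3, 7, and 11 are full-attention layers (hypothetical);
# the underlying MHATokenToKVPool then allocates just 3 layers, indexed 0..2.
full_attention_layer_ids = [3, 7, 11]
full_attention_layer_id_mapping = {
    layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)
}

assert full_attention_layer_id_mapping[7] == 1   # model layer 7 -> pool layer 1
# get_key_buffer(7) therefore reads self.full_kv_pool.get_key_buffer(1), while
# asking for a linear-attention layer (e.g. 4) raises ValueError.
```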
class SWAKVPool(KVCache):
    """KV cache with separate pools for full and SWA attention layers."""