[Fix] Init mamba-related memory pools with torch.zeros (#10400)
This commit is contained in:
@@ -127,7 +127,7 @@ class MambaPool:
|
|||||||
if speculative_num_draft_tokens is not None:
|
if speculative_num_draft_tokens is not None:
|
||||||
# Cache intermediate SSM states per draft token during target verify
|
# Cache intermediate SSM states per draft token during target verify
|
||||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
|
||||||
intermediate_ssm_state_cache = torch.empty(
|
intermediate_ssm_state_cache = torch.zeros(
|
||||||
size=(
|
size=(
|
||||||
num_mamba_layers,
|
num_mamba_layers,
|
||||||
size + 1,
|
size + 1,
|
||||||
@@ -141,7 +141,7 @@ class MambaPool:
|
|||||||
)
|
)
|
||||||
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
|
||||||
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
|
||||||
intermediate_conv_window_cache = torch.empty(
|
intermediate_conv_window_cache = torch.zeros(
|
||||||
size=(
|
size=(
|
||||||
num_mamba_layers,
|
num_mamba_layers,
|
||||||
size + 1,
|
size + 1,
|
||||||
@@ -240,7 +240,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
|||||||
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
|
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
|
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
|
||||||
size, dtype=torch.int32, device=self.device
|
size, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user