diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index ddcb19ea2..1bd684ad3 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -127,7 +127,7 @@ class MambaPool: if speculative_num_draft_tokens is not None: # Cache intermediate SSM states per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] - intermediate_ssm_state_cache = torch.empty( + intermediate_ssm_state_cache = torch.zeros( size=( num_mamba_layers, size + 1, @@ -141,7 +141,7 @@ class MambaPool: ) # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] - intermediate_conv_window_cache = torch.empty( + intermediate_conv_window_cache = torch.zeros( size=( num_mamba_layers, size + 1, @@ -240,7 +240,7 @@ class HybridReqToTokenPool(ReqToTokenPool): self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)} self.device = device - self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty( + self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros( size, dtype=torch.int32, device=self.device )