diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index f7eef2120..251d16aee 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -374,9 +374,9 @@ class MHATokenToKVPool(KVCache): # Overlap the copy of K and V cache for small batch size current_stream = self.device_module.current_stream() self.alt_stream.wait_stream(current_stream) + self.k_buffer[layer_id - self.start_layer][loc] = cache_k with self.device_module.stream(self.alt_stream): - self.k_buffer[layer_id - self.start_layer][loc] = cache_k - self.v_buffer[layer_id - self.start_layer][loc] = cache_v + self.v_buffer[layer_id - self.start_layer][loc] = cache_v current_stream.wait_stream(self.alt_stream) else: self.k_buffer[layer_id - self.start_layer][loc] = cache_k