diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 16bb8eb60..de30aab25 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -227,7 +227,8 @@ class MHATokenToKVPool(KVCache): self.layer_transfer_counter = None self.capture_mode = False - self.alt_stream = torch.cuda.Stream() + self.device_module = torch.get_device_module(self.device) + self.alt_stream = self.device_module.Stream() k_size, v_size = self.get_kv_size_bytes() logger.info( @@ -339,11 +340,12 @@ class MHATokenToKVPool(KVCache): cache_v = cache_v.view(self.store_dtype) if self.capture_mode and cache_k.shape[0] < 4: - self.alt_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.alt_stream): + current_stream = self.device_module.current_stream() + self.alt_stream.wait_stream(current_stream) + with self.device_module.stream(self.alt_stream): self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v - torch.cuda.current_stream().wait_stream(self.alt_stream) + current_stream.wait_stream(self.alt_stream) else: self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v