From 86d9baedc2be9b44613ab6ded4b386bf4c6a84d9 Mon Sep 17 00:00:00 2001 From: Chen Shengzhi Date: Sun, 16 Mar 2025 07:33:00 +0800 Subject: [PATCH] [Fix] Fix errors when using devices other than CUDA. (#4455) --- python/sglang/srt/mem_cache/memory_pool.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 16bb8eb60..de30aab25 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -227,7 +227,8 @@ class MHATokenToKVPool(KVCache): self.layer_transfer_counter = None self.capture_mode = False - self.alt_stream = torch.cuda.Stream() + self.device_module = torch.get_device_module(self.device) + self.alt_stream = self.device_module.Stream() k_size, v_size = self.get_kv_size_bytes() logger.info( @@ -339,11 +340,12 @@ class MHATokenToKVPool(KVCache): cache_v = cache_v.view(self.store_dtype) if self.capture_mode and cache_k.shape[0] < 4: - self.alt_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.alt_stream): + current_stream = self.device_module.current_stream() + self.alt_stream.wait_stream(current_stream) + with self.device_module.stream(self.alt_stream): self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v - torch.cuda.current_stream().wait_stream(self.alt_stream) + current_stream.wait_stream(self.alt_stream) else: self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v