[Fix] Fix errors when running on devices other than CUDA. (#4455)
This commit is contained in:
@@ -227,7 +227,8 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
self.capture_mode = False
|
self.capture_mode = False
|
||||||
self.alt_stream = torch.cuda.Stream()
|
self.device_module = torch.get_device_module(self.device)
|
||||||
|
self.alt_stream = self.device_module.Stream()
|
||||||
|
|
||||||
k_size, v_size = self.get_kv_size_bytes()
|
k_size, v_size = self.get_kv_size_bytes()
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -339,11 +340,12 @@ class MHATokenToKVPool(KVCache):
|
|||||||
cache_v = cache_v.view(self.store_dtype)
|
cache_v = cache_v.view(self.store_dtype)
|
||||||
|
|
||||||
if self.capture_mode and cache_k.shape[0] < 4:
|
if self.capture_mode and cache_k.shape[0] < 4:
|
||||||
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
current_stream = self.device_module.current_stream()
|
||||||
with torch.cuda.stream(self.alt_stream):
|
self.alt_stream.wait_stream(current_stream)
|
||||||
|
with self.device_module.stream(self.alt_stream):
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
torch.cuda.current_stream().wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
else:
|
else:
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
|
|||||||
Reference in New Issue
Block a user