Fix a regression introduced by overlapping KV cache writing (#4375)
This commit is contained in:
@@ -326,7 +326,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
cache_k = cache_k.view(self.store_dtype)
|
cache_k = cache_k.view(self.store_dtype)
|
||||||
cache_v = cache_v.view(self.store_dtype)
|
cache_v = cache_v.view(self.store_dtype)
|
||||||
|
|
||||||
if self.capture_mode:
|
if self.capture_mode and cache_k.shape[0] < 4:
|
||||||
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
self.alt_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
|
|||||||
Reference in New Issue
Block a user