[Fix]Fix capture fail bug for DeepSeek (#6275)
This commit is contained in:
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self._create_buffers()
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
self.capture_mode = False
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.alt_stream = self.device_module.Stream() if is_cuda else None
|
||||
|
||||
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
):
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
if k_scale is not None:
|
||||
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
if self.capture_mode and self.alt_stream is not None:
|
||||
if get_is_capture_mode() and self.alt_stream is not None:
|
||||
# Overlap the copy of K and V cache for small batch size
|
||||
current_stream = self.device_module.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
|
||||
Reference in New Issue
Block a user