Disable all two stream overlap on amd (#6475)

This commit is contained in:
Lianmin Zheng
2025-05-20 19:06:59 -07:00
committed by GitHub
parent 66324895c6
commit 03886917bd
3 changed files with 21 additions and 9 deletions

View File

@@ -38,11 +38,17 @@ import triton
import triton.language as tl
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_compiler_backend
from sglang.srt.utils import (
debug_timing,
get_compiler_backend,
is_cuda,
next_power_of_2,
)
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
class ReqToTokenPool:
@@ -262,7 +268,7 @@ class MHATokenToKVPool(KVCache):
self.layer_transfer_counter = None
self.capture_mode = False
self.device_module = torch.get_device_module(self.device)
self.alt_stream = self.device_module.Stream()
self.alt_stream = self.device_module.Stream() if is_cuda else None
k_size, v_size = self.get_kv_size_bytes()
logger.info(
@@ -392,7 +398,7 @@ class MHATokenToKVPool(KVCache):
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
if self.capture_mode and cache_k.shape[0] < 4:
if self.capture_mode and self.alt_stream is not None:
# Overlap the copy of K and V cache for small batch size
current_stream = self.device_module.current_stream()
self.alt_stream.wait_stream(current_stream)