From 03886917bd59f12a1420a99150997732ffea52da Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 20 May 2025 19:06:59 -0700
Subject: [PATCH] Disable all two stream overlap on amd (#6475)

---
 python/sglang/srt/mem_cache/memory_pool.py | 12 +++++++++---
 python/sglang/srt/models/deepseek_v2.py    |  6 ++----
 python/sglang/srt/models/llama4.py         | 12 ++++++++++--
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index cd7d653fc..8c88b9436 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -38,11 +38,17 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_compiler_backend
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()
 
 
 class ReqToTokenPool:
@@ -262,7 +268,7 @@ class MHATokenToKVPool(KVCache):
         self.layer_transfer_counter = None
         self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -392,7 +398,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and cache_k.shape[0] < 4:
+        if self.capture_mode and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 849e3c76d..0974102e3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -76,13 +76,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.operations import execute_operations
 from sglang.srt.operations_strategy import compute_layer_operations
@@ -1321,8 +1320,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        # TODO(haishaw): multi-stream performance on ROCm
-        self.alt_stream = None if _is_hip else torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index d309d0be1..082b97ae0 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD
 
     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
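
Note (editor's illustration, not part of the applied patch): the snippet below is a minimal, self-contained sketch of the gating pattern this change enforces. The alternate-stream overlap of the K and V cache copies is only set up when running on a CUDA build; every other backend (e.g. AMD/ROCm) falls back to a single stream. The class name KVWriter, the buffer arguments, and the local _is_cuda check are illustrative assumptions, not the sglang implementation; the model-side hunks in deepseek_v2.py and llama4.py apply the same idea by simply skipping their overlap branches when _is_cuda is false.

import torch

# Assumption: "CUDA build and not ROCm" approximates sglang's is_cuda() helper.
_is_cuda = torch.cuda.is_available() and torch.version.hip is None


class KVWriter:
    """Illustrative only: overlap K and V copies on a second stream when on CUDA."""

    def __init__(self, device: str = "cuda"):
        self.device_module = torch.get_device_module(device)
        # On non-CUDA backends keep alt_stream as None, as the patch does.
        self.alt_stream = self.device_module.Stream() if _is_cuda else None

    def write(self, k_buffer, v_buffer, loc, cache_k, cache_v):
        if self.alt_stream is not None:
            current = self.device_module.current_stream()
            self.alt_stream.wait_stream(current)
            with self.device_module.stream(self.alt_stream):
                k_buffer[loc] = cache_k  # K copy runs on the alternate stream
            v_buffer[loc] = cache_v  # V copy runs on the current stream
            current.wait_stream(self.alt_stream)
        else:
            # Single-stream fallback: the path AMD now always takes.
            k_buffer[loc] = cache_k
            v_buffer[loc] = cache_v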