From 323bc2f51a701737a05e883f3374ec0c91c7eab9 Mon Sep 17 00:00:00 2001 From: Chaitanya Sri Krishna Lolla Date: Sat, 9 Aug 2025 14:29:55 +0530 Subject: [PATCH] Enable TBO on ROCm (#8329) --- python/sglang/srt/two_batch_overlap.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 8e84b539b..23580a463 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput -from sglang.srt.utils import BumpAllocator, get_bool_env_var +from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import DispatchOutput +_is_hip = is_hip() + _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG") logger = logging.getLogger(__name__) @@ -822,9 +824,15 @@ def _model_forward_tbo( ) del inputs - with deep_gemm_wrapper.configure_deep_gemm_num_sms( - operations_strategy.deep_gemm_num_sms - ): + context = ( + empty_context() + if _is_hip + else deep_gemm_wrapper.configure_deep_gemm_num_sms( + operations_strategy.deep_gemm_num_sms + ) + ) + + with context: outputs_arr = execute_overlapped_operations( inputs_arr=inputs_arr, operations_arr=[operations_strategy.operations] * 2,