diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 8e84b539b..23580a463 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
 
+_is_hip = is_hip()
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -822,9 +824,15 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
-        operations_strategy.deep_gemm_num_sms
-    ):
+    context = (
+        empty_context()
+        if _is_hip
+        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            operations_strategy.deep_gemm_num_sms
+        )
+    )
+
+    with context:
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,