Enable TBO on ROCm (#8329)

This commit is contained in:
Chaitanya Sri Krishna Lolla
2025-08-09 14:29:55 +05:30
committed by GitHub
parent 137e75daa1
commit 323bc2f51a

View File

@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, get_bool_env_var
from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
_is_hip = is_hip()
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
logger = logging.getLogger(__name__)
@@ -822,9 +824,15 @@ def _model_forward_tbo(
)
del inputs
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
operations_strategy.deep_gemm_num_sms
):
context = (
empty_context()
if _is_hip
else deep_gemm_wrapper.configure_deep_gemm_num_sms(
operations_strategy.deep_gemm_num_sms
)
)
with context:
outputs_arr = execute_overlapped_operations(
inputs_arr=inputs_arr,
operations_arr=[operations_strategy.operations] * 2,