Enable TBO on ROCm (#8329)
This commit is contained in:
committed by GitHub
parent 137e75daa1
commit 323bc2f51a
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import BumpAllocator, get_bool_env_var
|
||||
from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -822,9 +824,15 @@ def _model_forward_tbo(
|
||||
)
|
||||
del inputs
|
||||
|
||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||
operations_strategy.deep_gemm_num_sms
|
||||
):
|
||||
context = (
|
||||
empty_context()
|
||||
if _is_hip
|
||||
else deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||
operations_strategy.deep_gemm_num_sms
|
||||
)
|
||||
)
|
||||
|
||||
with context:
|
||||
outputs_arr = execute_overlapped_operations(
|
||||
inputs_arr=inputs_arr,
|
||||
operations_arr=[operations_strategy.operations] * 2,
|
||||
|
||||
Reference in New Issue
Block a user