Enable TBO on ROCm (#8329)
This commit is contained in:
committed by
GitHub
parent
137e75daa1
commit
323bc2f51a
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||||
from sglang.srt.operations_strategy import OperationsStrategy
|
from sglang.srt.operations_strategy import OperationsStrategy
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import BumpAllocator, get_bool_env_var
|
from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -822,9 +824,15 @@ def _model_forward_tbo(
|
|||||||
)
|
)
|
||||||
del inputs
|
del inputs
|
||||||
|
|
||||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
context = (
|
||||||
operations_strategy.deep_gemm_num_sms
|
empty_context()
|
||||||
):
|
if _is_hip
|
||||||
|
else deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||||
|
operations_strategy.deep_gemm_num_sms
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with context:
|
||||||
outputs_arr = execute_overlapped_operations(
|
outputs_arr = execute_overlapped_operations(
|
||||||
inputs_arr=inputs_arr,
|
inputs_arr=inputs_arr,
|
||||||
operations_arr=[operations_strategy.operations] * 2,
|
operations_arr=[operations_strategy.operations] * 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user