[ROCm] To reduce the compiling time when using torch compile. (#10559)

This commit is contained in:
sogalin
2025-10-02 14:53:14 +08:00
committed by GitHub
parent 25e7dbe8af
commit c0dbbdd12b

View File

@@ -53,7 +53,9 @@ from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
+    get_bool_env_var,
     get_device_memory_capacity,
+    is_hip,
     log_info_on_rank0,
     require_attn_tp_gather,
     require_gathered_buffer,
@@ -61,6 +63,8 @@ from sglang.srt.utils import (
     require_mlp_tp_gather,
 )

+_is_hip = is_hip()
+
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
@@ -137,7 +141,7 @@ def patch_model(
             mode=os.environ.get(
                 "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
             ),
-            dynamic=False,
+            dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
         )
     else:
         yield model.forward