[ROCm] To reduce the compiling time when using torch compile. (#10559)
This commit is contained in:
@@ -53,7 +53,9 @@ from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
get_device_memory_capacity,
|
||||
is_hip,
|
||||
log_info_on_rank0,
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
@@ -61,6 +63,8 @@ from sglang.srt.utils import (
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -137,7 +141,7 @@ def patch_model(
|
||||
mode=os.environ.get(
|
||||
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
||||
),
|
||||
dynamic=False,
|
||||
dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
|
||||
)
|
||||
else:
|
||||
yield model.forward
|
||||
|
||||
Reference in New Issue
Block a user