[ROCm] To reduce the compiling time when using torch compile. (#10559)
This commit is contained in:
@@ -53,7 +53,9 @@ from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
empty_context,
|
empty_context,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
|
get_bool_env_var,
|
||||||
get_device_memory_capacity,
|
get_device_memory_capacity,
|
||||||
|
is_hip,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
require_gathered_buffer,
|
require_gathered_buffer,
|
||||||
@@ -61,6 +63,8 @@ from sglang.srt.utils import (
|
|||||||
require_mlp_tp_gather,
|
require_mlp_tp_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -137,7 +141,7 @@ def patch_model(
|
|||||||
mode=os.environ.get(
|
mode=os.environ.get(
|
||||||
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
||||||
),
|
),
|
||||||
dynamic=False,
|
dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield model.forward
|
yield model.forward
|
||||||
|
|||||||
Reference in New Issue
Block a user