Fix data parallel + tensor parallel (#4499)

This commit is contained in:
Lianmin Zheng
2025-03-17 05:13:16 -07:00
committed by GitHub
parent f2ab37e500
commit 5493c3343e
6 changed files with 53 additions and 16 deletions

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import bisect
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable
@@ -81,7 +82,9 @@ def patch_model(
# tp_group.ca_comm = None
yield torch.compile(
torch.no_grad()(model.forward),
mode="max-autotune-no-cudagraphs",
mode=os.environ.get(
"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
),
dynamic=False,
)
else: