Fix data parallel + tensor parallel (#4499)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
+import os
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
@@ -81,7 +82,9 @@ def patch_model(
|
||||
# tp_group.ca_comm = None
|
||||
yield torch.compile(
|
||||
torch.no_grad()(model.forward),
|
||||
-mode="max-autotune-no-cudagraphs",
|
||||
+mode=os.environ.get(
|
||||
+"SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
|
||||
+),
|
||||
dynamic=False,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user