Temporary work-around for a data-parallel enablement issue on the ROCm 7.0.0 alpha (#10434)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Sai Enduri <saimanas.enduri@amd.com>
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
 _LOCAL_ATTN_DP_RANK: Optional[int] = None
 _ENABLE_DP_ATTENTION_FLAG: bool = False
 
+_is_hip = is_hip()
+_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
+
 
 class DpPaddingMode(IntEnum):
 
@@ -67,7 +71,12 @@ class DpPaddingMode(IntEnum):
 
     @classmethod
     def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
-        return cls.MAX_LEN
+        # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
+        # it can be safely removed later, once RCCL fixed
+        if _USE_ROCM700A_WA:
+            return cls.SUM_LEN
+        else:
+            return cls.MAX_LEN
 
 
 class _DpGatheredBufferWrapper:
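For reference, the change amounts to an import-time environment flag that flips the default CUDA-graph padding mode from MAX_LEN to SUM_LEN on affected ROCm builds. Below is a minimal standalone sketch of that pattern. The get_bool_env_var and _is_hip stubs, and the comments on what the two padding modes mean, are assumptions made so the snippet runs outside sglang; the real helpers live in sglang.srt.utils and may behave differently.

    # Standalone sketch of the env-var-gated fallback this commit applies.
    from __future__ import annotations

    import os
    from enum import IntEnum


    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Assumed semantics: "1"/"true"/"yes" (case-insensitive) count as
        # truthy. The actual sglang helper may parse values differently.
        return os.getenv(name, default).lower() in ("1", "true", "yes")


    def _is_hip() -> bool:
        # Stand-in for sglang.srt.utils.is_hip(); pretend this is a ROCm/HIP
        # build so the flag can take effect when run by hand.
        return True


    # Evaluated once at import time, exactly as in the diff: the workaround
    # only engages on HIP builds with SGLANG_USE_ROCM700A set.
    _USE_ROCM700A_WA = _is_hip() and get_bool_env_var("SGLANG_USE_ROCM700A")


    class DpPaddingMode(IntEnum):
        # Assumed meanings: MAX_LEN pads every DP rank to the longest
        # sequence; SUM_LEN sizes the gathered buffer by the summed lengths.
        MAX_LEN = 0
        SUM_LEN = 1

        @classmethod
        def get_default_mode_in_cuda_graph(cls) -> "DpPaddingMode":
            # SUM_LEN sidesteps the RCCL issue seen on the ROCm 7.0.0 alpha;
            # MAX_LEN remains the default everywhere else.
            return cls.SUM_LEN if _USE_ROCM700A_WA else cls.MAX_LEN


    if __name__ == "__main__":
        print(DpPaddingMode.get_default_mode_in_cuda_graph())

On a build containing this commit, the workaround is opt-in: export SGLANG_USE_ROCM700A=1 on the affected ROCm 7.0.0 alpha before launching the server, and drop the variable once a fixed RCCL is available.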