diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py
index e9db969d3..4ab8635a8 100644
--- a/sgl-kernel/setup_rocm.py
+++ b/sgl-kernel/setup_rocm.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import os
 import platform
 import sys
 from pathlib import Path
@@ -50,7 +51,27 @@
 cxx_flags = ["-O3"]
 libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
-amdgpu_target = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
+# Target precedence: an explicit AMDGPU_TARGET wins, then the GPU visible to
+# torch, then the default. This keeps GPU-less and cross-target builds working.
+default_target = "gfx942"
+amdgpu_target = os.environ.get("AMDGPU_TARGET")
+
+if amdgpu_target:
+    print(f"Using AMDGPU_TARGET from environment: {amdgpu_target}")
+elif torch.cuda.is_available():
+    try:
+        amdgpu_target = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+    except Exception as e:
+        # Best-effort probe: fall back to the default instead of aborting the build.
+        amdgpu_target = default_target
+        print(
+            f"Warning: Failed to detect GPU properties: {e}. "
+            f"Using default target: {amdgpu_target}"
+        )
+else:
+    amdgpu_target = default_target
+    print(f"Warning: torch.cuda not available. Using default target: {amdgpu_target}")
+
 if amdgpu_target not in ["gfx942", "gfx950"]:
     print(
         f"Warning: Unsupported GPU architecture detected '{amdgpu_target}'. Expected 'gfx942' or 'gfx950'."