修改sgl-kernel下的setup_hip.py
This commit is contained in:
@@ -52,30 +52,10 @@ sources = [
|
|||||||
"csrc/kvcacheio/transfer.cu",
|
"csrc/kvcacheio/transfer.cu",
|
||||||
]
|
]
|
||||||
|
|
||||||
cxx_flags = ["-O3", "-w"]
|
cxx_flags = ["-O3"]
|
||||||
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
|
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
|
||||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
|
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
|
||||||
|
|
||||||
default_target = "gfx942"
|
|
||||||
amdgpu_target = os.environ.get("AMDGPU_TARGET", default_target)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
try:
|
|
||||||
amdgpu_target = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to detect GPU properties: {e}")
|
|
||||||
else:
|
|
||||||
print(f"Warning: torch.cuda not available. Using default target: {amdgpu_target}")
|
|
||||||
|
|
||||||
if amdgpu_target not in ["gfx942", "gfx950", "gfx936"]:
|
|
||||||
print(
|
|
||||||
f"Warning: Unsupported GPU architecture detected '{amdgpu_target}'. Expected 'gfx942' or 'gfx950'."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
fp8_macro = (
|
|
||||||
"-DHIP_FP8_TYPE_FNUZ" if amdgpu_target == "gfx942" else "-DHIP_FP8_TYPE_E4M3"
|
|
||||||
)
|
|
||||||
|
|
||||||
hipcc_flags = [
|
hipcc_flags = [
|
||||||
"-DNDEBUG",
|
"-DNDEBUG",
|
||||||
@@ -84,10 +64,8 @@ hipcc_flags = [
|
|||||||
"-Xcompiler",
|
"-Xcompiler",
|
||||||
"-fPIC",
|
"-fPIC",
|
||||||
"-std=c++17",
|
"-std=c++17",
|
||||||
f"--amdgpu-target={amdgpu_target}",
|
|
||||||
"-DENABLE_BF16",
|
"-DENABLE_BF16",
|
||||||
"-DENABLE_FP8",
|
"-DENABLE_FP8",
|
||||||
fp8_macro,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [
|
||||||
|
|||||||
Reference in New Issue
Block a user