fix accuracy issue (#4376)

This commit is contained in:
Yineng Zhang
2025-03-13 02:06:22 -07:00
committed by GitHub
parent cf721fdece
commit 2937387a50
4 changed files with 17 additions and 5 deletions

View File

@@ -178,6 +178,8 @@ if torch.cuda.is_available():
if cuda_version >= (12, 8) and sm_version >= 100:
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
else:
nvcc_flags.append("-use_fast_math")
if sm_version >= 90:
nvcc_flags.extend(nvcc_flags_fp8)
if sm_version >= 80:
@@ -188,6 +190,8 @@ else:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if enable_sm100a:
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
else:
nvcc_flags.append("-use_fast_math")
if enable_fp8:
nvcc_flags.extend(nvcc_flags_fp8)
if enable_bf16: