fix accuracy issue (#4376)
This commit is contained in:
@@ -178,6 +178,8 @@ if torch.cuda.is_available():
|
||||
if cuda_version >= (12, 8) and sm_version >= 100:
|
||||
nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
else:
|
||||
nvcc_flags.append("-use_fast_math")
|
||||
if sm_version >= 90:
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if sm_version >= 80:
|
||||
@@ -188,6 +190,8 @@ else:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
if enable_sm100a:
|
||||
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
|
||||
else:
|
||||
nvcc_flags.append("-use_fast_math")
|
||||
if enable_fp8:
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if enable_bf16:
|
||||
|
||||
Reference in New Issue
Block a user