[Fix] fix fa3 build at cu118 (#5036)

This commit is contained in:
yinfan98
2025-04-04 02:52:35 +08:00
committed by GitHub
parent 8e10fec9a8
commit b8b6008f47
8 changed files with 288 additions and 142 deletions

View File

@@ -3,15 +3,22 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
try:
from sgl_kernel import flash_ops
except:
raise ImportError("Can not import sgl_kernel. Please check your installation.")
def is_fa3_supported(device=None) -> bool:
# FA3 can fail without a enough shared memory for a some shapes, currently
# only 8.0 and 8.7 have enough shared memory for all shapes
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
return FA3_AVAILABLE and (
torch.cuda.get_device_capability(device)[0] >= 9
or torch.cuda.get_device_capability(device) == (8, 0)
or torch.cuda.get_device_capability(device) == (8, 7)
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
return (
(torch.cuda.get_device_capability(device)[0] >= 9)
and (torch.version.cuda >= "12.4")
# or torch.cuda.get_device_capability(device) == (8, 0)
# or torch.cuda.get_device_capability(device) == (8, 7)
)
@@ -135,6 +142,10 @@ def flash_attn_with_kvcache(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None: