[Feat] Scale up fa3 kernel to sm8x arch (#5912)
Co-authored-by: zhyncs <me@zhyncs.com>
@@ -233,7 +233,7 @@ install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
 # ============================ Optional Install ============================= #

 # set flash-attention sources file
-# BF16 source files
+# Now FA3 supports sm80/sm86/sm90
 if (SGL_KERNEL_ENABLE_FA3)
     set(SGL_FLASH_KERNEL_CUDA_FLAGS
         "-DNDEBUG"
@@ -241,6 +241,8 @@ if (SGL_KERNEL_ENABLE_FA3)
     "-O3"
     "-Xcompiler"
     "-fPIC"
+    "-gencode=arch=compute_80,code=sm_80"
+    "-gencode=arch=compute_86,code=sm_86"
     "-gencode=arch=compute_90a,code=sm_90a"
     "-std=c++17"
     "-DCUTE_USE_PACKED_TUPLE=1"
@@ -256,6 +258,10 @@ if (SGL_KERNEL_ENABLE_FA3)
     "-Xcompiler=-fno-strict-aliasing"
 )

+# SM8X Logic
+file(GLOB FA3_SM8X_GEN_SRCS
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
+
 file(GLOB FA3_BF16_GEN_SRCS
     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
 file(GLOB FA3_BF16_GEN_SRCS_
@@ -276,7 +282,7 @@ if (SGL_KERNEL_ENABLE_FA3)
     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
 list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

-set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
+set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})

 set(FLASH_SOURCES
     "csrc/flash_extension.cc"
@@ -297,7 +303,7 @@ if (SGL_KERNEL_ENABLE_FA3)
 install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")

 target_compile_definitions(flash_ops PRIVATE
-    FLASHATTENTION_DISABLE_SM8x
+    # FLASHATTENTION_DISABLE_SM8x
     FLASHATTENTION_DISABLE_BACKWARD
     FLASHATTENTION_DISABLE_DROPOUT
     FLASHATTENTION_DISABLE_UNEVEN_K
@@ -81,6 +81,14 @@ Third-party libraries:
 - [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
 - [FlashAttention](https://github.com/Dao-AILab/flash-attention)

+### FlashAttention FYI
+
+FA3 can fail when there is not enough shared memory for some shapes, such as a higher hidden_dim or certain special cases. Right now, FA3 is supported on sm80/sm87 and sm86/sm89.
+
+The main difference between sm80/sm87 and sm86/sm89 is the shared memory size; see https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x for more information.
+
+For sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a. That means if you use **A100 (tested)**/A*0/**L20 (tested)**/L40/L40s/**3090 (tested)**, you can use FA3.
+
 ### Kernel Development

 Steps to add a new kernel:
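As an aside to the FYI above, a quick way to check whether a given GPU falls into the supported range is to query its compute capability with PyTorch. The sketch below only mirrors the `is_fa3_supported()` gate added in this PR; the helper name and the printout are illustrative, and it assumes a CUDA build of PyTorch.

```python
import torch

def fa3_expected_to_work(device=None) -> bool:
    # sm90 (Hopper) and sm8x (Ampere/Ada) are both accepted after this change,
    # provided the CUDA toolkit is at least 12.3 (mirrors is_fa3_supported()).
    major = torch.cuda.get_device_capability(device)[0]
    return major in (8, 9) and torch.version.cuda >= "12.3"

if __name__ == "__main__":
    name = torch.cuda.get_device_name(0)
    print(f"{name}: FA3 {'usable' if fa3_expected_to_work(0) else 'not supported'}")
```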
@@ -10,10 +10,18 @@ except:
 
 
 def is_fa3_supported(device=None) -> bool:
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
+    # Some FA3 FYI:
+    # FA3 can fail when there is not enough shared memory for some shapes,
+    # such as a higher hidden_dim or certain special cases.
+    # Right now, FA3 is supported on sm80/sm87 and sm86/sm89. The main difference
+    # between them is the shared memory size; for more information see
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # For sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a.
+    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use FA3.
+    return (
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    ) and (torch.version.cuda >= "12.3")
 
 
 def maybe_contiguous(x):
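For completeness, here is a sketch of how such a gate is typically consumed on the caller or test side. The skip marker and test below are hypothetical and not part of this PR; they simply restate the same capability/CUDA checks inline so the snippet is self-contained.

```python
import pytest
import torch

# Hypothetical test guard (not part of this PR): skip FA3 tests on devices or
# CUDA toolkits that the updated is_fa3_supported() gate would reject.
requires_fa3 = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability()[0] not in (8, 9)
    or torch.version.cuda < "12.3",
    reason="FA3 needs sm80/sm86/sm89/sm90a and CUDA >= 12.3",
)

@requires_fa3
def test_fa3_smoke():
    assert torch.cuda.is_available()
```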
@@ -11,17 +11,24 @@ from einops import rearrange, repeat
 apply_rotary_emb = None
 
 
+def is_hopper():
+    # Only Hopper supports a different V headdim
+    return torch.cuda.get_device_properties(0).major >= 9
+
+
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without a enough shared memory for a some shapes, currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes
+    # Some FA3 FYI:
+    # FA3 can fail when there is not enough shared memory for some shapes,
+    # such as a higher hidden_dim or certain special cases.
+    # Right now, FA3 is supported on sm80/sm87 and sm86/sm89. The main difference
+    # between them is the shared memory size; for more information see
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
+    # For sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a.
+    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use FA3.
     return (
-        (torch.cuda.get_device_capability(device)[0] == 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
-    )
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    ) and (torch.version.cuda >= "12.3")
 
 
 DISABLE_BACKWARD = True
@@ -558,7 +565,8 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-    if dtype == torch.float8_e4m3fn:
+    if dtype == torch.float8_e4m3fn or not is_hopper():
+        # For fp8 and the Ampere arch we do not support V head dim != QK head dim
         dv_vals = [d]
     for dv in dv_vals:
         has_qv = d == 64 and dv >= 256
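To make the effect of the new `not is_hopper()` branch concrete, here is a small worked example of the `dv_vals` selection above. It is a standalone sketch that mirrors the test logic; the `(d, hopper, fp8)` tuples are illustrative inputs, not cases taken from the test matrix.

```python
# Standalone sketch mirroring the dv_vals selection in test_flash_attn_kvcache.
for d, hopper, fp8 in [(64, True, False), (64, False, False), (192, True, False)]:
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if fp8 or not hopper:
        # fp8 and Ampere-class devices only exercise V head dim == QK head dim
        dv_vals = [d]
    print(d, hopper, fp8, dv_vals)
# d=64  hopper=True  fp8=False -> [256, 512, 64]  (different V headdim is tested)
# d=64  hopper=False fp8=False -> [64]            (sm8x: V headdim must equal QK headdim)
# d=192 hopper=True  fp8=False -> [128, 192]
```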