[feat] add fa3 in sgl-kernel (#4902)

Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
yinfan98
2025-03-31 03:57:10 +08:00
committed by GitHub
parent 9adf178cc2
commit 37c66ec856
7 changed files with 1300 additions and 0 deletions

View File

@@ -92,6 +92,36 @@ Steps to add a new kernel:
)
```
### Integrating Third-Party Libraries with Data Type Conversion
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
```
To use this with your library functions, simply wrap them with make_pytorch_shim:
```cpp
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
```
### Build & Install
Development build: