[feat] add fa3 in sgl-kernel (#4902)
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -92,6 +92,36 @@ Steps to add a new kernel:
|
||||
)
|
||||
```
|
||||
|
||||
### Integrating Third-Party Libraries with Data Type Conversion
|
||||
|
||||
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
|
||||
|
||||
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
|
||||
|
||||
When you need to support new data type conversions, you can easily add conversion functions like this:
|
||||
|
||||
```cpp
|
||||
// Map `int` -> `int64_t`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<int> {
|
||||
using type = int64_t;
|
||||
static int convert_from_type(int64_t arg) {
|
||||
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
|
||||
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
To use this with your library functions, simply wrap them with make_pytorch_shim:
|
||||
|
||||
```cpp
|
||||
/*
|
||||
* From flash-attention
|
||||
*/
|
||||
m.def("fwd", make_pytorch_shim(mha_fwd));
|
||||
```
|
||||
|
||||
### Build & Install
|
||||
|
||||
Development build:
|
||||
|
||||
Reference in New Issue
Block a user