[Fix] revert clean m.def for cudagraph (#4944)
@@ -55,9 +55,17 @@ Steps to add a new kernel:
1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `torch::Tensor`, include `<torch/all.h>` instead of `<torch/extension.h>`; using `<torch/extension.h>` will cause compilation errors when building against the stable ABI (SABI). A sketch of this file split appears after this list.
2. When creating torch extensions, add the function definition with `m.def` and the device binding with `m.impl`:
- Using `torch.compile` requires `m.def` with a schema, which lets it capture the custom kernel automatically. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
- How to write a schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)
```cpp
// Plain def without a schema (cudagraph buffer-registration helper).
m.def("register_graph_buffers", register_graph_buffers);

// We need def with schema here for torch.compile.
// `Tensor!` marks an argument mutated in place; `-> ()` means the op returns nothing.
m.def(
    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
    "cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
```
3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
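For step 1, the split between a pure CUDA file and a thin C++ interface can look roughly like the sketch below. This is an illustrative assumption rather than code from the repository: the file names, `scale_kernel`, `scale_launcher`, and `my_scale` are all made up.

```cpp
// ---- scale_kernel.cu : pure CUDA, no torch headers (hypothetical file) ----
__global__ void scale_kernel(const float* __restrict__ in, float* __restrict__ out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

void scale_launcher(const float* in, float* out, int n) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  scale_kernel<<<blocks, threads>>>(in, out, n);
}

// ---- my_scale.cc : C++ interface, torch types come from <torch/all.h> ----
#include <torch/all.h>

void scale_launcher(const float* in, float* out, int n);  // defined in the .cu file

void my_scale(torch::Tensor in, torch::Tensor out) {
  scale_launcher(in.data_ptr<float>(), out.data_ptr<float>(), static_cast<int>(in.numel()));
}
```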
@@ -96,6 +104,8 @@ Steps to add a new kernel:
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and the PyTorch bindings. For example, the third-party code might use `float` or `int`, while the binding registered through `TORCH_LIBRARY` requires `double` and `int64_t`.

> The reason we need `double` and `int64_t` in the torch binding is that `TORCH_LIBRARY` handles the Python-to-C++ conversion: Python's `float` corresponds to `double` in C++, and Python's `int` corresponds to `int64_t`.
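As a concrete illustration of that mapping, here is a hedged sketch; the op, namespace, and helper below are invented for this example and are not from sgl-kernel. The function registered with `TORCH_LIBRARY` takes `double`/`int64_t` and narrows them to the `float`/`int` the underlying implementation expects:

```cpp
#include <torch/all.h>

// Hypothetical third-party entry point that uses float/int.
torch::Tensor scale_rows_impl(torch::Tensor x, float alpha, int repeat);

// Wrapper seen by the dispatcher: Python float/int arrive as double/int64_t.
torch::Tensor scale_rows(torch::Tensor x, double alpha, int64_t repeat) {
  return scale_rows_impl(x, static_cast<float>(alpha), static_cast<int>(repeat));
}

TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  // In schema syntax, `float` means double and `int` means int64_t.
  m.def("scale_rows(Tensor x, float alpha, int repeat) -> Tensor");
  m.impl("scale_rows", torch::kCUDA, &scale_rows);
}
```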
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
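The concrete example that sentence refers to is not shown in this hunk; below is a minimal sketch of what such a conversion specialization can look like, assuming the shim dispatches on a `pytorch_library_compatible_type`-style trait with a `convert_from_type` hook. The exact template and member names in sgl_kernel_torch_shim.h may differ.

```cpp
// Hedged sketch (inside sgl_kernel_torch_shim.h, which already has the needed
// includes): map the binding-side type (double) back to the kernel-side type
// (float). Trait and member names are assumptions.
template <>
struct pytorch_library_compatible_type<float> {
  using type = double;  // what TORCH_LIBRARY sees on the Python side
  static float convert_from_type(double arg) {
    TORCH_CHECK(arg <= std::numeric_limits<float>::max(),
                "value too large to convert to float");
    TORCH_CHECK(arg >= std::numeric_limits<float>::lowest(),
                "value too small to convert to float");
    return static_cast<float>(arg);  // narrow to what the kernel expects
  }
};
```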
@@ -119,7 +129,7 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
```cpp
/*
 * From flash-attention
 */
m.def("fwd", make_pytorch_shim(mha_fwd));
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
```
### Build & Install