[feat] add fa3 in sgl-kernel (#4902)

Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
yinfan98
2025-03-31 03:57:10 +08:00
committed by GitHub
parent 9adf178cc2
commit 37c66ec856
7 changed files with 1300 additions and 0 deletions

View File

@@ -23,6 +23,8 @@ limitations under the License.
#include <vector>
#include "sgl_kernel_torch_shim.h"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);
/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);

View File

@@ -0,0 +1,122 @@
/*Adapt from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <torch/library.h>
/**
* Unforunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim` which creates a lambda that exponses the API using
* PyTorch compatible types to the types, then converts them to the types
* expected by the flash_attn ops. This shims allows us to make minimal changes
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issues is
* that the following types are not support by PyTorch libary bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/
template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) {
return arg;
}
};
template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
using type = const c10::optional<T>&;
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
return const_cast<c10::optional<T>&>(arg);
}
};
// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
return arg;
}
};
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
using type = const c10::optional<at::Tensor>&;
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
}
};
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
return arg;
}
};
//
// Shim Utils
//
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}