[feat] add fa3 in sgl-kernel (#4902)
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -23,6 +23,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sgl_kernel_torch_shim.h"
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
/*
 * From flash-attention (FlashAttention-3 forward pass entry point).
 * Declaration only — the definition presumably lives in the flash-attention
 * sources (flash_api.cpp); verify against the vendored FA3 tree.
 */
std::vector<at::Tensor> mha_fwd(
    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                          // h_k, d) if there is page_table.
    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
    std::optional<const at::Tensor>&
        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>&
        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor>& out_,        // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
    std::optional<const at::Tensor>&
        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor>&
        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int> max_seqlen_k_,
    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
    std::optional<const at::Tensor>& leftpad_k_,     // b
    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& seqlens_rotary_,  // b
    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
    int num_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);
|
||||
|
||||
122
sgl-kernel/include/sgl_kernel_torch_shim.h
Normal file
122
sgl-kernel/include/sgl_kernel_torch_shim.h
Normal file
@@ -0,0 +1,122 @@
|
||||
/*Adapted from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#pragma once

#include <torch/library.h>

#include <cmath>
#include <cstdint>
#include <limits>
|
||||
|
||||
/**
 * Unfortunately, the type signatures of the flash_attn ops are not compatible
 * with the PyTorch library bindings. To get around that we use
 * `make_pytorch_shim` which creates a lambda that exposes the API using
 * PyTorch compatible types to the types, then converts them to the types
 * expected by the flash_attn ops. This shim allows us to make minimal changes
 * to `flash_api.cpp` making it easier to synchronize with upstream changes.
 *
 * The `pytorch_library_compatible_type` struct is used to map from the
 * flash_attn ops types to a PyTorch library compatible one. The main issue is
 * that the following types are not supported by PyTorch library bindings:
 *  - `int`
 *  - `float`
 *  - `std::optional<T> &`
 *  - `std::optional<const at::Tensor> &`
 * So we convert them to (respectively):
 *  - `int64_t`
 *  - `double`
 *  - `const std::optional<T>&`
 *  - `const std::optional<at::Tensor>&`
 */
|
||||
|
||||
/// Identity mapping: by default a type is already compatible with the
/// PyTorch library bindings, so it is passed through unchanged.
template <typename T>
struct pytorch_library_compatible_type {
  using type = T;
  static T convert_from_type(T value) {
    return value;
  }
};
|
||||
|
||||
// Shorthand for the binding-compatible counterpart of `T`.
template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
|
||||
|
||||
// Convert a binding-compatible argument back into the type `T` that the
// underlying flash_attn op expects (e.g. int64_t -> int).
template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
  return pytorch_library_compatible_type<T>::convert_from_type(arg);
}
|
||||
|
||||
// Map `c10::optional<T>&` -> `const c10::optional<T>&`
// (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
  using type = const c10::optional<T>&;
  // const_cast is required because the op signature takes a mutable
  // reference while the PyTorch binding only supplies a const one.
  static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
    return const_cast<c10::optional<T>&>(arg);
  }
};
|
||||
|
||||
// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
// Both the parameter binding and the return here go through c10::optional's
// converting constructor (e.g. optional<int64_t> <-> optional<int>); unlike
// the scalar int/float maps, this path performs no range checking.
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
  using type = c10::optional<pytorch_library_compatible_type_t<T>>;
  static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
    return arg;
  }
};
|
||||
|
||||
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
  using type = const c10::optional<at::Tensor>&;
  // NOTE(review): the reinterpret_cast reads optional<Tensor> storage as
  // optional<const Tensor>; this assumes the two instantiations share an
  // identical layout (holds in practice, not guaranteed by the standard).
  static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
    return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
  }
};
|
||||
|
||||
// Map `int` -> `int64_t`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<int> {
|
||||
using type = int64_t;
|
||||
static int convert_from_type(int64_t arg) {
|
||||
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
|
||||
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
// Map `float` -> `double`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<float> {
|
||||
using type = double;
|
||||
static float convert_from_type(double arg) {
|
||||
TORCH_CHECK(
|
||||
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
|
||||
//
// Shim Utils
//

// Wrap a flash_attn op `fun` in a lambda whose parameter list uses only
// PyTorch-library-compatible types; each argument is converted back to the
// op's native type before the call is forwarded. The returned lambda is what
// gets registered with the PyTorch library bindings.
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
  return [fun](pytorch_library_compatible_type_t<Args>... args) {
    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
  };
}
|
||||
Reference in New Issue
Block a user