[Feat] Update sgl-kernel flashinfer to latest main version (#5500)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
|
|||||||
# flashinfer
|
# flashinfer
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-flashinfer
|
repo-flashinfer
|
||||||
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
|
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
||||||
GIT_TAG sgl-kernel
|
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
|
||||||
GIT_SHALLOW OFF
|
GIT_SHALLOW OFF
|
||||||
)
|
)
|
||||||
FetchContent_Populate(repo-flashinfer)
|
FetchContent_Populate(repo-flashinfer)
|
||||||
|
|||||||
@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
/*
|
/*
|
||||||
* From csrc/elementwise
|
* From csrc/elementwise
|
||||||
*/
|
*/
|
||||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||||
|
|
||||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||||
|
|
||||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||||
|
|
||||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||||
|
|
||||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||||
@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
|
||||||
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
"min_p_val, bool deterministic, Generator? gen) -> ()");
|
||||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||||
|
|
||||||
m.def(
|
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||||
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
|
||||||
"cuda_stream) -> ()");
|
|
||||||
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
||||||
|
|
||||||
m.def(
|
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
|
||||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
|
||||||
"cuda_stream) -> ()");
|
|
||||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
|
||||||
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||||
"cuda_stream) -> ()");
|
|
||||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ limitations under the License.
|
|||||||
|
|
||||||
using namespace flashinfer;
|
using namespace flashinfer;
|
||||||
|
|
||||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
void sgl_fused_add_rmsnorm(
|
||||||
|
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
|
||||||
CHECK_INPUT(input);
|
CHECK_INPUT(input);
|
||||||
CHECK_INPUT(residual);
|
CHECK_INPUT(residual);
|
||||||
CHECK_INPUT(weight);
|
CHECK_INPUT(weight);
|
||||||
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
|
|||||||
static_cast<c_type*>(weight.data_ptr()),
|
static_cast<c_type*>(weight.data_ptr()),
|
||||||
batch_size,
|
batch_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
|
input.stride(0),
|
||||||
|
residual.stride(0),
|
||||||
eps,
|
eps,
|
||||||
|
enable_pdl,
|
||||||
torch_current_stream);
|
torch_current_stream);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||||
|
|||||||
@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
DType threshold_acc) {
|
DType threshold_acc) {
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
|
|
||||||
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||||
uint8_t smem_sampling[];
|
uint8_t smem_sampling[];
|
||||||
auto& temp_storage =
|
auto& temp_storage =
|
||||||
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
|
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
|
||||||
|
|
||||||
DType prob_acc = 0.0;
|
DType prob_acc = 0.0;
|
||||||
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
||||||
@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
|
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
|
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
|
||||||
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
|
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
|
||||||
if (aggregate_relu_q_minus_p > u) {
|
if (aggregate_relu_q_minus_p > u) {
|
||||||
break;
|
break;
|
||||||
@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||||
|
|
||||||
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
|
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
|
||||||
|
|||||||
@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
|
|||||||
/*
|
/*
|
||||||
* From csrc/elementwise
|
* From csrc/elementwise
|
||||||
*/
|
*/
|
||||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
|
||||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
void sgl_fused_add_rmsnorm(
|
||||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
|
||||||
void gemma_fused_add_rmsnorm(
|
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
|
||||||
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
|
||||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||||
@@ -254,48 +254,38 @@ void segment_packbits(
|
|||||||
*/
|
*/
|
||||||
void min_p_sampling_from_probs(
|
void min_p_sampling_from_probs(
|
||||||
at::Tensor probs,
|
at::Tensor probs,
|
||||||
at::Tensor uniform_samples,
|
at::Tensor output,
|
||||||
at::Tensor samples,
|
std::optional<at::Tensor> maybe_indices,
|
||||||
std::optional<at::Tensor> maybe_min_p_arr,
|
std::optional<at::Tensor> maybe_min_p_arr,
|
||||||
double min_p_val,
|
double min_p_val,
|
||||||
bool deterministic,
|
bool deterministic,
|
||||||
int64_t cuda_stream);
|
std::optional<at::Generator> gen);
|
||||||
|
|
||||||
void top_k_renorm_probs(
|
void top_k_renorm_probs(
|
||||||
at::Tensor probs,
|
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||||
at::Tensor renorm_probs,
|
|
||||||
std::optional<at::Tensor> maybe_top_k_arr,
|
|
||||||
int64_t top_k_val,
|
|
||||||
int64_t cuda_stream);
|
|
||||||
|
|
||||||
void top_p_renorm_probs(
|
void top_p_renorm_probs(
|
||||||
at::Tensor probs,
|
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
|
||||||
at::Tensor renorm_probs,
|
|
||||||
std::optional<at::Tensor> maybe_top_p_arr,
|
|
||||||
double top_p_val,
|
|
||||||
int64_t cuda_stream);
|
|
||||||
|
|
||||||
void top_k_top_p_sampling_from_probs(
|
void top_k_top_p_sampling_from_probs(
|
||||||
at::Tensor probs,
|
at::Tensor probs,
|
||||||
at::Tensor uniform_samples,
|
at::Tensor output,
|
||||||
at::Tensor samples,
|
std::optional<at::Tensor> maybe_indices,
|
||||||
at::Tensor success,
|
|
||||||
std::optional<at::Tensor> maybe_top_k_arr,
|
std::optional<at::Tensor> maybe_top_k_arr,
|
||||||
double top_k_val,
|
double top_k_val,
|
||||||
std::optional<at::Tensor> maybe_top_p_arr,
|
std::optional<at::Tensor> maybe_top_p_arr,
|
||||||
double top_p_val,
|
double top_p_val,
|
||||||
bool deterministic,
|
bool deterministic,
|
||||||
int64_t cuda_stream);
|
std::optional<at::Generator> gen);
|
||||||
|
|
||||||
void top_p_sampling_from_probs(
|
void top_p_sampling_from_probs(
|
||||||
at::Tensor probs,
|
at::Tensor probs,
|
||||||
at::Tensor uniform_samples,
|
at::Tensor output,
|
||||||
at::Tensor samples,
|
std::optional<at::Tensor> maybe_indices,
|
||||||
at::Tensor success,
|
|
||||||
std::optional<at::Tensor> maybe_top_p_arr,
|
std::optional<at::Tensor> maybe_top_p_arr,
|
||||||
double top_p_val,
|
double top_p_val,
|
||||||
bool deterministic,
|
bool deterministic,
|
||||||
int64_t cuda_stream);
|
std::optional<at::Generator> gen);
|
||||||
|
|
||||||
namespace flash {
|
namespace flash {
|
||||||
/*
|
/*
|
||||||
|
|||||||
@@ -11,17 +11,69 @@ def rmsnorm(
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
out: Optional[torch.Tensor] = None,
|
out: Optional[torch.Tensor] = None,
|
||||||
|
enable_pdl: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
r"""Root mean square normalization.
|
||||||
|
|
||||||
|
``out[i] = (input[i] / RMS(input)) * weight[i]``
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input: torch.Tensor
|
||||||
|
Input tensor, shape (batch_size, hidden_size).
|
||||||
|
weight: torch.Tensor
|
||||||
|
Weight tensor, shape (hidden_size,).
|
||||||
|
eps: float
|
||||||
|
Epsilon for numerical stability.
|
||||||
|
out: Optional[torch.Tensor]
|
||||||
|
The output tensor, if specified, the kernel will update this tensor inplace.
|
||||||
|
enable_pdl: bool
|
||||||
|
Whether to enable `programmatic dependent launch
|
||||||
|
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
output: torch.Tensor
|
||||||
|
Normalized tensor, shape (batch_size, hidden_size).
|
||||||
|
"""
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
|
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def fused_add_rmsnorm(
|
def fused_add_rmsnorm(
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
enable_pdl: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
|
r"""Fused add root mean square normalization.
|
||||||
|
|
||||||
|
Step 1:
|
||||||
|
``residual[i] += input[i]``
|
||||||
|
|
||||||
|
Step 2:
|
||||||
|
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input: torch.Tensor
|
||||||
|
Input tensor, shape (batch_size, hidden_size).
|
||||||
|
residual: torch.Tensor
|
||||||
|
Residual tensor, shape (batch_size, hidden_size).
|
||||||
|
weight: torch.Tensor
|
||||||
|
Weight tensor, shape (hidden_size,).
|
||||||
|
eps: float
|
||||||
|
Epsilon for numerical stability.
|
||||||
|
enable_pdl: bool
|
||||||
|
Whether to enable `programmatic dependent launch
|
||||||
|
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||||
|
"""
|
||||||
|
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
|
||||||
|
input, residual, weight, eps, enable_pdl
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def gemma_rmsnorm(
|
def gemma_rmsnorm(
|
||||||
@@ -29,20 +81,68 @@ def gemma_rmsnorm(
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
out: Optional[torch.Tensor] = None,
|
out: Optional[torch.Tensor] = None,
|
||||||
|
enable_pdl: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
r"""Gemma-style root mean square normalization.
|
||||||
|
|
||||||
|
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input: torch.Tensor
|
||||||
|
Input tensor, shape (batch_size, hidden_size).
|
||||||
|
weight: torch.Tensor
|
||||||
|
Weight tensor, shape (hidden_size,).
|
||||||
|
eps: float
|
||||||
|
Epsilon for numerical stability.
|
||||||
|
out: Optional[torch.Tensor]
|
||||||
|
The output tensor, if specified, the kernel will update this tensor inplace.
|
||||||
|
enable_pdl: bool
|
||||||
|
Whether to enable `programmatic dependent launch
|
||||||
|
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
output: torch.Tensor
|
||||||
|
Gemma Normalized tensor, shape (batch_size, hidden_size).
|
||||||
|
"""
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
torch.ops.sgl_kernel.gemma_rmsnorm.default(
|
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
|
||||||
out, input, weight, eps, get_cuda_stream()
|
|
||||||
)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def gemma_fused_add_rmsnorm(
|
def gemma_fused_add_rmsnorm(
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
enable_pdl: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
r"""Gemma-style fused add root mean square normalization.
|
||||||
|
|
||||||
|
Step 1:
|
||||||
|
``residual[i] += input[i]``
|
||||||
|
|
||||||
|
Step 2:
|
||||||
|
``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input: torch.Tensor
|
||||||
|
Input tensor, shape (batch_size, hidden_size).
|
||||||
|
residual: torch.Tensor
|
||||||
|
Residual tensor, shape (batch_size, hidden_size).
|
||||||
|
weight: torch.Tensor
|
||||||
|
Weight tensor, shape (hidden_size,).
|
||||||
|
eps: float
|
||||||
|
Epsilon for numerical stability.
|
||||||
|
enable_pdl: bool
|
||||||
|
Whether to enable `programmatic dependent launch
|
||||||
|
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||||
|
"""
|
||||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
|
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
|
||||||
input, residual, weight, eps, get_cuda_stream()
|
input, residual, weight, eps, enable_pdl
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,7 @@ def _top_k_renorm_probs_internal(
|
|||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
torch.ops.sgl_kernel.top_k_renorm_probs.default(
|
torch.ops.sgl_kernel.top_k_renorm_probs.default(
|
||||||
probs,
|
probs, renorm_probs, maybe_top_k_arr, top_k_val
|
||||||
renorm_probs,
|
|
||||||
maybe_top_k_arr,
|
|
||||||
top_k_val,
|
|
||||||
get_cuda_stream(),
|
|
||||||
)
|
)
|
||||||
return renorm_probs
|
return renorm_probs
|
||||||
|
|
||||||
@@ -26,6 +22,30 @@ def top_k_renorm_probs(
|
|||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
top_k: Union[torch.Tensor, int],
|
top_k: Union[torch.Tensor, int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
|
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
probs: torch.Tensor
|
||||||
|
Probabilities, shape ``(batch_size, num_classes)``.
|
||||||
|
top_k: Union[torch.Tensor, int]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
|
||||||
|
for re-normalizing probabilities, should be in ``(0, num_classes)``.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
renorm_probs: torch.Tensor
|
||||||
|
Renormalized probabilities, shape ``(batch_size, num_classes)``.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
|
||||||
|
``top_k_sampling_from_probs``.
|
||||||
|
"""
|
||||||
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
||||||
|
|
||||||
|
|
||||||
@@ -41,11 +61,7 @@ def _top_p_renorm_probs_internal(
|
|||||||
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
torch.ops.sgl_kernel.top_p_renorm_probs.default(
|
torch.ops.sgl_kernel.top_p_renorm_probs.default(
|
||||||
probs,
|
probs, renorm_probs, maybe_top_p_arr, top_p_val
|
||||||
renorm_probs,
|
|
||||||
maybe_top_p_arr,
|
|
||||||
top_p_val,
|
|
||||||
get_cuda_stream(),
|
|
||||||
)
|
)
|
||||||
return renorm_probs
|
return renorm_probs
|
||||||
|
|
||||||
@@ -54,6 +70,32 @@ def top_p_renorm_probs(
|
|||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
top_p: Union[torch.Tensor, float],
|
top_p: Union[torch.Tensor, float],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
|
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
probs: torch.Tensor
|
||||||
|
Probabilities, shape ``(batch_size, num_classes)``.
|
||||||
|
top_p: Union[torch.Tensor, float]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for
|
||||||
|
re-normalizing probabilities, should be in ``(0, 1)``.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
We mask out the probabilities less than `threshold` where the cumulative sum
|
||||||
|
of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
renorm_probs: torch.Tensor
|
||||||
|
Renormalized probabilities, shape ``(batch_size, num_classes)``.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
|
||||||
|
``top_p_sampling_from_probs``.
|
||||||
|
|
||||||
|
"""
|
||||||
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
||||||
|
|
||||||
|
|
||||||
@@ -62,93 +104,187 @@ top_p_renorm_prob = top_p_renorm_probs
|
|||||||
|
|
||||||
def _top_p_sampling_from_probs_internal(
|
def _top_p_sampling_from_probs_internal(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
indices: Optional[torch.Tensor],
|
||||||
maybe_top_p_arr: Optional[torch.Tensor],
|
maybe_top_p_arr: Optional[torch.Tensor],
|
||||||
top_p_val: float,
|
top_p_val: float,
|
||||||
deterministic: bool,
|
deterministic: bool,
|
||||||
|
generator: Optional[torch.Generator],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
with probs.device as device:
|
with probs.device as device:
|
||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_top_p_arr = (
|
maybe_top_p_arr = (
|
||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
|
||||||
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
|
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
|
||||||
samples,
|
samples,
|
||||||
success,
|
indices,
|
||||||
maybe_top_p_arr,
|
maybe_top_p_arr,
|
||||||
top_p_val,
|
top_p_val,
|
||||||
deterministic,
|
deterministic,
|
||||||
get_cuda_stream(),
|
generator,
|
||||||
)
|
)
|
||||||
return samples, success
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def top_p_sampling_from_probs(
|
def top_p_sampling_from_probs(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
top_p: Union[torch.Tensor, float],
|
top_p: Union[torch.Tensor, float],
|
||||||
|
indices: Optional[torch.Tensor] = None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
check_nan: bool = False,
|
check_nan: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
|
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
|
||||||
|
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||||
|
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||||
|
|
||||||
|
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||||
|
which is more efficient than the naive implementation that launches a series of kernels.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
probs: torch.Tensor
|
||||||
|
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||||
|
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||||
|
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||||
|
probability distributions.
|
||||||
|
top_p: Union[torch.Tensor, float]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
indices: Optional[torch.Tensor]
|
||||||
|
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||||
|
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||||
|
This allows reusing the same probability distribution for multiple outputs.
|
||||||
|
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||||
|
deterministic: bool
|
||||||
|
Whether to use deterministic kernel implementation, default is ``True``.
|
||||||
|
generator: Optional[torch.Generator]
|
||||||
|
A random number generator for the operation.
|
||||||
|
check_nan: bool
|
||||||
|
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
samples: torch.Tensor
|
||||||
|
Sampled categories, shape ``(batch_size,)``.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This function expects float32 inputs, and the output is int32.
|
||||||
|
|
||||||
|
"""
|
||||||
if check_nan:
|
if check_nan:
|
||||||
if torch.any(torch.isnan(probs)):
|
if torch.any(torch.isnan(probs)):
|
||||||
raise ValueError("Input probs contains NaN.")
|
raise ValueError("Input probs contains NaN.")
|
||||||
return _top_p_sampling_from_probs_internal(
|
return _top_p_sampling_from_probs_internal(
|
||||||
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
|
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _top_k_top_p_sampling_from_probs_internal(
|
def _top_k_top_p_sampling_from_probs_internal(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
indices: Optional[torch.Tensor],
|
||||||
maybe_top_k_arr: Optional[torch.Tensor],
|
maybe_top_k_arr: Optional[torch.Tensor],
|
||||||
top_k_val: int,
|
top_k_val: int,
|
||||||
maybe_top_p_arr: Optional[torch.Tensor],
|
maybe_top_p_arr: Optional[torch.Tensor],
|
||||||
top_p_val: float,
|
top_p_val: float,
|
||||||
deterministic: bool,
|
deterministic: bool,
|
||||||
|
generator: Optional[torch.Generator],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
with probs.device as device:
|
with probs.device as device:
|
||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||||
maybe_top_p_arr = (
|
maybe_top_p_arr = (
|
||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
|
||||||
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
|
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
|
||||||
samples,
|
samples,
|
||||||
success,
|
indices,
|
||||||
maybe_top_k_arr,
|
maybe_top_k_arr,
|
||||||
top_k_val,
|
top_k_val,
|
||||||
maybe_top_p_arr,
|
maybe_top_p_arr,
|
||||||
top_p_val,
|
top_p_val,
|
||||||
deterministic,
|
deterministic,
|
||||||
get_cuda_stream(),
|
generator,
|
||||||
)
|
)
|
||||||
return samples, success
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_sampling_from_probs(
|
def top_k_top_p_sampling_from_probs(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
top_k: Union[torch.Tensor, int],
|
top_k: Union[torch.Tensor, int],
|
||||||
top_p: Union[torch.Tensor, float],
|
top_p: Union[torch.Tensor, float],
|
||||||
|
indices: Optional[torch.Tensor] = None,
|
||||||
filter_apply_order: str = "top_k_first",
|
filter_apply_order: str = "top_k_first",
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
check_nan: bool = False,
|
check_nan: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
|
Fused GPU kernel for top-k and top-p sampling from probabilities,
|
||||||
|
|
||||||
|
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||||
|
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||||
|
|
||||||
|
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||||
|
which is more efficient than the naive implementation that launches a series of kernels.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
probs: torch.Tensor
|
||||||
|
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||||
|
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||||
|
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||||
|
probability distributions.
|
||||||
|
top_k: Union[torch.Tensor, int]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
top_p: Union[torch.Tensor, float]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
indices: Optional[torch.Tensor]
|
||||||
|
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||||
|
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||||
|
This allows reusing the same probability distribution for multiple outputs.
|
||||||
|
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||||
|
filter_apply_order: str
|
||||||
|
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
|
||||||
|
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||||
|
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
|
||||||
|
deterministic: bool
|
||||||
|
Whether to use deterministic kernel implementation, default is ``True``.
|
||||||
|
generator: Optional[torch.Generator]
|
||||||
|
A random number generator for the operation.
|
||||||
|
check_nan: bool
|
||||||
|
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
samples: torch.Tensor
|
||||||
|
Sampled categories, shape ``(batch_size,)``.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This function expects float32 inputs, and the output is int32.
|
||||||
|
|
||||||
|
"""
|
||||||
if filter_apply_order == "top_k_first":
|
if filter_apply_order == "top_k_first":
|
||||||
renorm_probs = top_k_renorm_probs(probs, top_k)
|
renorm_probs = top_k_renorm_probs(probs, top_k)
|
||||||
return top_p_sampling_from_probs(
|
return top_p_sampling_from_probs(
|
||||||
renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
|
renorm_probs,
|
||||||
|
top_p,
|
||||||
|
indices,
|
||||||
|
deterministic,
|
||||||
|
check_nan=check_nan,
|
||||||
|
generator=generator,
|
||||||
)
|
)
|
||||||
elif filter_apply_order == "joint":
|
elif filter_apply_order == "joint":
|
||||||
if check_nan:
|
if check_nan:
|
||||||
@@ -156,10 +292,11 @@ def top_k_top_p_sampling_from_probs(
|
|||||||
raise ValueError("Input probs contains NaN.")
|
raise ValueError("Input probs contains NaN.")
|
||||||
return _top_k_top_p_sampling_from_probs_internal(
|
return _top_k_top_p_sampling_from_probs_internal(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
indices,
|
||||||
*_to_tensor_scalar_tuple(top_k),
|
*_to_tensor_scalar_tuple(top_k),
|
||||||
*_to_tensor_scalar_tuple(top_p),
|
*_to_tensor_scalar_tuple(top_p),
|
||||||
deterministic,
|
deterministic,
|
||||||
|
generator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||||
@@ -167,44 +304,82 @@ def top_k_top_p_sampling_from_probs(
|
|||||||
|
|
||||||
def _min_p_sampling_from_probs_internal(
|
def _min_p_sampling_from_probs_internal(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
indices: Optional[torch.Tensor],
|
||||||
maybe_min_p_arr: Optional[torch.Tensor],
|
maybe_min_p_arr: Optional[torch.Tensor],
|
||||||
min_p_val: float,
|
min_p_val: float,
|
||||||
deterministic: bool,
|
deterministic: bool,
|
||||||
|
generator: Optional[torch.Generator],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
with probs.device as device:
|
with probs.device as device:
|
||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_min_p_arr = (
|
maybe_min_p_arr = (
|
||||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
|
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
|
||||||
samples,
|
samples,
|
||||||
|
indices,
|
||||||
maybe_min_p_arr,
|
maybe_min_p_arr,
|
||||||
min_p_val,
|
min_p_val,
|
||||||
deterministic,
|
deterministic,
|
||||||
get_cuda_stream(),
|
generator,
|
||||||
)
|
)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def min_p_sampling_from_probs(
|
def min_p_sampling_from_probs(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
min_p: Union[torch.Tensor, float],
|
min_p: Union[torch.Tensor, float],
|
||||||
|
indices: Optional[torch.Tensor] = None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
check_nan: bool = False,
|
check_nan: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if uniform_samples.dim() == 2:
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
# Take the first row (round) of uniform_samples
|
Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
|
||||||
uniform_samples = uniform_samples[0]
|
|
||||||
|
|
||||||
|
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||||
|
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||||
|
|
||||||
|
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||||
|
which is more efficient than the naive implementation that launches a series of kernels.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
probs: torch.Tensor
|
||||||
|
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||||
|
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||||
|
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||||
|
probability distributions.
|
||||||
|
min_p: Union[torch.Tensor, float]
|
||||||
|
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
|
||||||
|
If a scalar, the same threshold is used for all requests.
|
||||||
|
If a tensor, each request has its own threshold.
|
||||||
|
indices: Optional[torch.Tensor]
|
||||||
|
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||||
|
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||||
|
This allows reusing the same probability distribution for multiple outputs.
|
||||||
|
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||||
|
deterministic: bool
|
||||||
|
Whether to use deterministic kernel implementation, default is ``True``.
|
||||||
|
generator: Optional[torch.Generator]
|
||||||
|
A random number generator for the operation.
|
||||||
|
check_nan: bool
|
||||||
|
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
samples: torch.Tensor
|
||||||
|
Sampled categories, shape ``(batch_size,)``.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This function expects float32 inputs, and the output is int32.
|
||||||
|
"""
|
||||||
if check_nan:
|
if check_nan:
|
||||||
if torch.any(torch.isnan(probs)):
|
if torch.any(torch.isnan(probs)):
|
||||||
raise ValueError("Input probs contains NaN.")
|
raise ValueError("Input probs contains NaN.")
|
||||||
return _min_p_sampling_from_probs_internal(
|
return _min_p_sampling_from_probs_internal(
|
||||||
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
|
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import sgl_kernel
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||||
@pytest.mark.parametrize("p", [0.1, 0.5])
|
@pytest.mark.parametrize("p", [0.1, 0.5])
|
||||||
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
@@ -16,14 +16,13 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
|||||||
k = int(vocab_size * 0.1)
|
k = int(vocab_size * 0.1)
|
||||||
else:
|
else:
|
||||||
raise ValueError("p not recognized")
|
raise ValueError("p not recognized")
|
||||||
max_top_k_trails = 32
|
|
||||||
eps = 1e-4
|
eps = 1e-4
|
||||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||||
# top-p mask
|
# top-p mask
|
||||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||||
cdf = torch.cumsum(sorted_prob, dim=-1)
|
cdf = torch.cumsum(sorted_prob, dim=-1)
|
||||||
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||||
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
|
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
|
||||||
# top-k mask
|
# top-k mask
|
||||||
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
||||||
@@ -31,40 +30,35 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
|||||||
mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
||||||
# overall mask
|
# overall mask
|
||||||
mask = torch.minimum(mask_top_p, mask_top_k)
|
mask = torch.minimum(mask_top_p, mask_top_k)
|
||||||
uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
|
top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
|
||||||
0
|
top_k_tensor = torch.full((batch_size,), k, device="cuda:0")
|
||||||
)
|
|
||||||
top_p_tensor = torch.full((batch_size,), p).to(0)
|
|
||||||
top_k_tensor = torch.full((batch_size,), k).to(0)
|
|
||||||
|
|
||||||
num_trails = 1000
|
num_trails = 1000
|
||||||
for _ in range(num_trails):
|
for _ in range(num_trails):
|
||||||
uniform_samples.uniform_()
|
samples = sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||||
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
|
|
||||||
normalized_prob,
|
normalized_prob,
|
||||||
uniform_samples,
|
|
||||||
top_k_tensor,
|
top_k_tensor,
|
||||||
top_p_tensor,
|
top_p_tensor,
|
||||||
filter_apply_order="joint",
|
filter_apply_order="joint",
|
||||||
)
|
)
|
||||||
assert torch.all(success)
|
|
||||||
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
|
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
|
||||||
assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
|
assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
|
||||||
torch.arange(batch_size), samples
|
torch.arange(batch_size), samples
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||||
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
|
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
|
||||||
def test_top_p_renorm_probs(batch_size, vocab_size, p):
|
def test_top_p_renorm_probs(batch_size, vocab_size, p):
|
||||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
torch.manual_seed(42)
|
||||||
|
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||||
cdf = torch.cumsum(sorted_prob, dim=-1)
|
cdf = torch.cumsum(sorted_prob, dim=-1)
|
||||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||||
mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
|
mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
|
||||||
renorm_prob_ground_truth = normalized_prob
|
renorm_prob_ground_truth = normalized_prob.clone()
|
||||||
renorm_prob_ground_truth[mask == 0] = 0
|
renorm_prob_ground_truth[mask == 0] = 0
|
||||||
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
||||||
dim=-1, keepdim=True
|
dim=-1, keepdim=True
|
||||||
@@ -79,56 +73,54 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||||
@pytest.mark.parametrize("k", [10, 100, 500])
|
@pytest.mark.parametrize("k", [10, 100, 500])
|
||||||
def test_top_k_renorm_probs(batch_size, vocab_size, k):
|
def test_top_k_renorm_probs(batch_size, vocab_size, k):
|
||||||
if k > vocab_size:
|
if k > vocab_size:
|
||||||
pytest.skip("k should be less than vocab_size")
|
pytest.skip("k should be less than vocab_size")
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||||
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
||||||
pivot = sorted_prob[:, k - 1]
|
pivot = sorted_prob[:, k - 1]
|
||||||
mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
||||||
renorm_prob_ground_truth = normalized_prob
|
renorm_prob_ground_truth = normalized_prob.clone()
|
||||||
renorm_prob_ground_truth[mask == 0] = 0
|
renorm_prob_ground_truth[mask == 0] = 0
|
||||||
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
||||||
dim=-1, keepdim=True
|
dim=-1, keepdim=True
|
||||||
)
|
)
|
||||||
|
|
||||||
renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
|
renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
|
||||||
torch.testing.assert_close(
|
for i in range(batch_size):
|
||||||
renorm_prob_ground_truth,
|
torch.testing.assert_close(
|
||||||
renorm_prob,
|
renorm_prob_ground_truth[i],
|
||||||
rtol=1e-3,
|
renorm_prob[i],
|
||||||
atol=1e-3,
|
rtol=1e-3,
|
||||||
)
|
atol=1e-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||||
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
|
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
|
||||||
def test_min_p_sampling(batch_size, vocab_size, p):
|
def test_min_p_sampling(batch_size, vocab_size, p):
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||||
# scale min-p
|
# scale min-p
|
||||||
top_probs = sorted_prob[:, -1].unsqueeze(-1)
|
top_probs = sorted_prob[:, -1].unsqueeze(-1)
|
||||||
scaled_p = p * top_probs
|
scaled_p = p * top_probs
|
||||||
# min-p mask
|
# min-p mask
|
||||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||||
mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
|
mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
|
||||||
uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
|
min_p_tensor = torch.full((batch_size,), p, device="cuda:0")
|
||||||
min_p_tensor = torch.full((batch_size,), p).to(0)
|
|
||||||
|
|
||||||
num_trails = 1000
|
num_trails = 1000
|
||||||
for _ in range(num_trails):
|
for _ in range(num_trails):
|
||||||
uniform_samples.uniform_()
|
|
||||||
samples = sgl_kernel.min_p_sampling_from_probs(
|
samples = sgl_kernel.min_p_sampling_from_probs(
|
||||||
normalized_prob,
|
normalized_prob,
|
||||||
uniform_samples,
|
|
||||||
min_p_tensor,
|
min_p_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -136,6 +128,10 @@ def test_min_p_sampling(batch_size, vocab_size, p):
|
|||||||
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
|
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
|
||||||
|
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user