Add typo checker in pre-commit (#6179)
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
@@ -114,7 +114,7 @@ set(SGL_KERNEL_CUDA_FLAGS
     "--expt-extended-lambda"
     "--threads=32"

-    # Supress warnings
+    # Suppress warnings
     "-Xcompiler=-Wconversion"
     "-Xcompiler=-fno-strict-aliasing"
@@ -87,7 +87,7 @@ Third-party libraries:

 The main different Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x.

-And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. Thats mean if you use **A100(tested)**/A*0/**L20(tested)**/L40/L40s/**3090(tested)** you can use fa3.
+And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. That means if you use **A100(tested)**/A*0/**L20(tested)**/L40/L40s/**3090(tested)** you can use fa3.

 ### Kernel Development
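The support rule described in this docs hunk boils down to the GPU's CUDA compute capability: sm80/sm86/sm89 report major version 8 and sm90a reports 9. Below is a minimal sketch of such a runtime check, assuming PyTorch is available; the helper name is illustrative, not the actual sgl-kernel function.

```python
import torch

def can_use_fa3_sketch(device=None) -> bool:
    """Hypothetical check: FA3 is assumed buildable for sm80/sm86/sm89/sm90a,
    i.e. compute capability major version 8 or 9."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device)
    return major in (8, 9)
```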
@@ -20,7 +20,7 @@ def _per_token_group_quant_8bit(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
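The `_per_token_group_quant_8bit` parameters above (input stride, column count `N`, and an `eps` guard against dividing by zero) describe per-token-group quantization: each group of contiguous elements gets its own scale. The sketch below is a rough plain-PyTorch reference of that idea, assuming symmetric int8 quantization with groups along the last dimension; the names and exact scaling rule are illustrative, not the kernel's contract.

```python
import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Quantize each group of `group_size` contiguous elements to int8 with one
    scale per group; `eps` avoids dividing by zero for all-zero groups."""
    assert x.shape[-1] % group_size == 0  # groups are taken along the last dim
    orig_shape = x.shape
    xf = x.reshape(-1, group_size).float()
    scale = (xf.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(eps)
    q = torch.clamp(torch.round(xf / scale), -128, 127).to(torch.int8)
    return q.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)
```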
@@ -49,7 +49,7 @@ namespace {

 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_LAST_DIM_CONTIGUOUS(x) \
-  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
+  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")

 #define CHECK_INPUT(x) \
   CHECK_CPU(x);        \
@@ -718,7 +718,7 @@ void decode_attention_kernel_impl(

 m_prime = m_i;

-// caculate V' <- s_delta @ V + V' * m_delta
+// calculate V' <- s_delta @ V + V' * m_delta
 index_gemm_kernel_nn<scalar_t, index_t>(
     /* A */ s_delta,
     /* B */ v_buffer + head_id * v_strideH,
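The comment corrected here documents the streaming-softmax accumulator update used by the decode kernel: when a new block raises the running maximum, the partial output V' is rescaled by m_delta, and the block's exponentiated scores s_delta are accumulated with a small GEMM. Below is a toy PyTorch sketch of that single-block update, assuming one head and leaving out the final normalization by the running sum; variable names follow the comment, not the kernel's real buffers.

```python
import torch

def online_softmax_block_update(v_prime, m_prime, s_block, v_block):
    """One block of a streaming softmax-attention accumulation.

    v_prime : [head_size_v]              running (unnormalized) output V'
    m_prime : scalar tensor              running max of the scores seen so far
    s_block : [block_len]                raw attention scores for this block
    v_block : [block_len, head_size_v]   values for this block
    """
    m_i = torch.maximum(m_prime, s_block.max())      # new running max
    m_delta = torch.exp(m_prime - m_i)                # rescale factor for the old V'
    s_delta = torch.exp(s_block - m_i)                # this block's exp-scores
    v_prime = s_delta @ v_block + v_prime * m_delta   # V' <- s_delta @ V + V' * m_delta
    return v_prime, m_i
```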
@@ -925,7 +925,7 @@ void decode_attention_grouped_kernel_impl(
   m_prime[h] = m_i;
 }

-// caculate V' <- s_delta @ V + V' * m_delta
+// calculate V' <- s_delta @ V + V' * m_delta
 index_gemm_kernel_nn<scalar_t, index_t>(
     /* A */ s_delta,
     /* B */ v_buffer + head_kv_id * v_strideH,
@@ -323,7 +323,7 @@ void extend_attention_kernel_impl(
     /* ld_src */ v_strideN,
     /* ld_dst */ head_size_v);

-// caculate V' <- s_delta @ V + V'
+// calculate V' <- s_delta @ V + V'
 at::native::cpublas::brgemm(
     /* M */ m_size,
     /* N */ head_size_v,
@@ -434,7 +434,7 @@ void extend_attention_kernel_impl(
     /* ld_src */ ve_strideN,
     /* ld_dst */ head_size_v);

-// caculate V' <- s_delta @ V + V'
+// calculate V' <- s_delta @ V + V'
 at::native::cpublas::brgemm(
     /* M */ m_size,
     /* N */ head_size_v,
@@ -79,7 +79,7 @@ void fused_experts_int8_kernel_impl(
     int64_t topk,
     int64_t num_tokens_post_pad);

-// shared expert implememntation for int8 w8a8
+// shared expert implementation for int8 w8a8
 template <typename scalar_t>
 void shared_expert_int8_kernel_impl(
     scalar_t* __restrict__ output,
@@ -51,7 +51,7 @@ struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
   __m512 vd0;
   __m512 vd1[COLS];

-  // oops! 4x4 spills but luckly we use 4x2
+  // oops! 4x4 spills but we use 4x2
   __m512 vbias[COLS];

   // [NOTE]: s8s8 igemm compensation in avx512-vnni
@@ -1,4 +1,4 @@
-// This is only a pluggin used for flashinfer 0.1.6. The new version does not need it.
+// This is only a plugin used for flashinfer 0.1.6. The new version does not need it.
 /*
  * Copyright (c) 2025 by SGLang team.
  * Copyright (c) 2025 by FlashInfer team.
@@ -20,16 +20,16 @@ limitations under the License.
 #include <torch/library.h>

 /**
- * Unforunately, the type signatures of the flash_attn ops are not compatible
+ * Unfortunately, the type signatures of the flash_attn ops are not compatible
  * with the PyTorch library bindings. To get around that we use
- * `make_pytorch_shim` which creates a lambda that exponses the API using
+ * `make_pytorch_shim` which creates a lambda that exposes the API using
  * PyTorch compatible types to the types, then converts them to the types
  * expected by the flash_attn ops. This shims allows us to make minimal changes
  * to `flash_api.cpp` making it easier to synchronize with upstream changes.
  *
  * The `pytorch_library_compatible_type` struct is used to map from the
  * flash_attn ops types to a PyTorch library compatible one. The main issues is
- * that the following types are not support by PyTorch libary bindings:
+ * that the following types are not support by PyTorch library bindings:
  * - `int`
  * - `float`
  * - `std::optional<T> &`
@@ -229,7 +229,7 @@ def apply_rope_with_cos_sin_cache_inplace(
         Whether to use Neox style RoPE, default: ``True``.

         * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
-          we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
           dimensions ``([..., head_dim//2:])``.

         * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
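The docstring fixed above contrasts the two RoPE layouts. Below is a small illustrative sketch of both rotations, assuming precomputed `cos`/`sin` of shape `[..., head_dim // 2]`; it paraphrases the docstring and is not the fused kernel.

```python
import torch

def rotate_neox_style(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Neox style (is_neox=True): rotate the first half [..., :d//2] against the second half.
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def rotate_interleaved_style(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Interleaved style (is_neox=False): pair up even/odd positions along the last dimension.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out
```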
@@ -17,7 +17,7 @@ def is_fa3_supported(device=None) -> bool:
     # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
-    # Thats mean if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
+    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
     return (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
@@ -45,10 +45,10 @@ def moe_fused_gate(
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
-    # as the group weight to select exerpt groups and then select topk experts within the selected groups
+    # as the group weight to select expert groups and then select topk experts within the selected groups
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
-    # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
-    # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
+    # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
+    # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
     # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
     # routed_scaling_factor: if > 0, the last expert will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
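The comments above describe a two-level selection: experts are split into `num_expert_group` groups, each group is scored by the sum of its top-2 expert weights, the best groups are kept, and the final top-k experts are chosen within them. Below is a rough, unoptimized sketch of that logic in the spirit of the `biased_grouped_topk` fallback mentioned in the comments; the tensor names and exact scoring are assumptions, not the fused kernel's implementation.

```python
import torch

def grouped_topk_sketch(scores: torch.Tensor, num_expert_group: int, topk_group: int, topk: int):
    """scores: [num_tokens, num_experts] gating scores.
    Returns (topk_weights, topk_ids), chosen only from the selected expert groups."""
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_expert_group
    group_scores = scores.view(num_tokens, num_expert_group, experts_per_group)
    # Score each group by the sum of its top-2 expert weights, keep the best groups.
    group_weight = group_scores.topk(2, dim=-1).values.sum(dim=-1)      # [T, G]
    keep_groups = group_weight.topk(topk_group, dim=-1).indices         # [T, topk_group]
    # Mask out experts in non-selected groups, then take the final top-k
    # (assumes topk <= topk_group * experts_per_group).
    keep_mask = torch.zeros(num_tokens, num_expert_group, dtype=torch.bool, device=scores.device)
    keep_mask.scatter_(1, keep_groups, True)
    expert_mask = keep_mask.repeat_interleave(experts_per_group, dim=1)  # [T, E]
    masked = scores.masked_fill(~expert_mask, float("-inf"))
    topk_weights, topk_ids = masked.topk(topk, dim=-1)
    return topk_weights, topk_ids
```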
@@ -24,7 +24,7 @@ def is_fa3_supported(device=None) -> bool:
     # Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
-    # Thats mean if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
+    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
     return (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
@@ -21,7 +21,7 @@ def _per_token_group_quant_fp8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,