[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
PGFLMG
2025-04-18 03:43:23 +08:00
committed by GitHub
parent f13d65a7ea
commit c08a717c77
8 changed files with 393 additions and 133 deletions

View File

@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
DType threshold_acc) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
DType prob_acc = 0.0;
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
if (aggregate_relu_q_minus_p > u) {
break;
@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);