[Refactor] Reducing code duplication across FP8 CUDA quantization kernels (#4163)

This commit is contained in:
Stefan He
2025-03-06 22:58:52 -08:00
committed by GitHub
parent c7f254468f
commit 95085d65e9
5 changed files with 32 additions and 64 deletions

View File

@@ -1,13 +1,12 @@
import itertools import itertools
import math from typing import Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8 from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn

View File

@@ -40,9 +40,6 @@ def calculate_diff(batch_size: int, seq_len: int):
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item() output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
print(f"Scale difference: {scale_diff}")
print(f"Output difference: {output_diff}")
if torch.allclose( if torch.allclose(
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):

View File

@@ -7,38 +7,6 @@
#include "utils.h" #include "utils.h"
// Lanes per warp. NOTE(review): warpReduceMax hard-codes shuffle offsets
// 16..1, which corresponds to this value.
#define WARP_SIZE 32
// Pick the FP8 element type and its maximum-value constant per platform.
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
// Maximum value reported by std::numeric_limits for the selected FP8 type.
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using PyTorch's default max value (240.0) causes accuracy issues when
// running dynamic quantization, so ROCm builds use 224.0f instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
// Atomically updates the float at `addr` so that it holds the larger of its
// current contents and `value`, returning what was stored there beforehand.
//
// The comparison is performed on the raw bit patterns with integer atomics,
// choosing the integer view according to the sign test on `value`.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    // Non-negative floats order the same way as their signed-int bit patterns.
    const int prev_bits = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
    return __int_as_float(prev_bits);
  }
  // Negative floats order in reverse when their bits are read as unsigned
  // ints, so an unsigned atomicMin keeps the numerically larger float.
  const unsigned int prev_bits = atomicMin(reinterpret_cast<unsigned int*>(addr), __float_as_uint(value));
  return __uint_as_float(prev_bits);
}
// Butterfly (XOR-shuffle) max reduction over a 32-lane warp.
//
// Each step combines a lane's running maximum with the value held by the lane
// whose index differs by `lane_offset` (16, 8, 4, 2, 1). The full mask
// 0xffffffff names all 32 lanes as participants in every shuffle.
__device__ __forceinline__ float warpReduceMax(float max_value) {
  constexpr unsigned int kFullWarpMask = 0xffffffff;
#pragma unroll
  for (int lane_offset = 16; lane_offset > 0; lane_offset >>= 1) {
    const float partner_value = __shfl_xor_sync(kFullWarpMask, max_value, lane_offset);
    max_value = fmaxf(max_value, partner_value);
  }
  return max_value;
}
template <typename T> template <typename T>
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
const int64_t num_elements) { const int64_t num_elements) {

View File

@@ -1,5 +1,4 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath> #include <cmath>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
@@ -7,31 +6,6 @@
#include "utils.h" #include "utils.h"
// Lanes per warp. NOTE(review): warpReduceMax hard-codes shuffle offsets
// 16..1, which corresponds to this value.
#define WARP_SIZE 32
// Pick the FP8 element type and its maximum-value constant per platform.
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
// Maximum value reported by std::numeric_limits for the selected FP8 type.
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using PyTorch's default max value (240.0) causes accuracy issues when
// running dynamic quantization, so ROCm builds use 224.0f instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
// Butterfly (XOR-shuffle) max reduction over a 32-lane warp.
//
// Each step combines a lane's running maximum with the value held by the lane
// whose index differs by `lane_offset` (16, 8, 4, 2, 1). The full mask
// 0xffffffff names all 32 lanes as participants in every shuffle.
__device__ __forceinline__ float warpReduceMax(float max_value) {
  constexpr unsigned int kFullWarpMask = 0xffffffff;
#pragma unroll
  for (int lane_offset = 16; lane_offset > 0; lane_offset >>= 1) {
    const float partner_value = __shfl_xor_sync(kFullWarpMask, max_value, lane_offset);
    max_value = fmaxf(max_value, partner_value);
  }
  return max_value;
}
template <typename T> template <typename T>
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q, __global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s, const int64_t hidden_dim, float* __restrict__ output_s, const int64_t hidden_dim,

View File

@@ -95,3 +95,33 @@ inline int getSMVersion() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y)) #define CEILDIV(x, y) (((x) + (y)-1) / (y))
// Lanes per warp. NOTE(review): warpReduceMax hard-codes shuffle offsets
// 16..1, which corresponds to this value.
#define WARP_SIZE 32
// Pick the FP8 element type and its maximum-value constant per platform.
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
// Maximum value reported by std::numeric_limits for the selected FP8 type.
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using PyTorch's default max value (240.0) causes accuracy issues when
// running dynamic quantization, so ROCm builds use 224.0f instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
// Atomically updates the float at `addr` so that it holds the larger of its
// current contents and `value`, returning what was stored there beforehand.
//
// The comparison is performed on the raw bit patterns with integer atomics,
// choosing the integer view according to the sign test on `value`.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    // Non-negative floats order the same way as their signed-int bit patterns.
    const int prev_bits = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
    return __int_as_float(prev_bits);
  }
  // Negative floats order in reverse when their bits are read as unsigned
  // ints, so an unsigned atomicMin keeps the numerically larger float.
  const unsigned int prev_bits = atomicMin(reinterpret_cast<unsigned int*>(addr), __float_as_uint(value));
  return __uint_as_float(prev_bits);
}
// Butterfly (XOR-shuffle) max reduction over a 32-lane warp.
//
// Each step combines a lane's running maximum with the value held by the lane
// whose index differs by `lane_offset` (16, 8, 4, 2, 1). The full mask
// 0xffffffff names all 32 lanes as participants in every shuffle.
__device__ __forceinline__ float warpReduceMax(float max_value) {
  constexpr unsigned int kFullWarpMask = 0xffffffff;
#pragma unroll
  for (int lane_offset = 16; lane_offset > 0; lane_offset >>= 1) {
    const float partner_value = __shfl_xor_sync(kFullWarpMask, max_value, lane_offset);
    max_value = fmaxf(max_value, partner_value);
  }
  return max_value;
}