[Refactor] Reducing code duplication across FP8 CUDA quantization kernels (#4163)
This commit is contained in:
@@ -1,13 +1,12 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
from typing import Tuple
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||||
|
|
||||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
is_hip_ = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
|||||||
@@ -40,9 +40,6 @@ def calculate_diff(batch_size: int, seq_len: int):
|
|||||||
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
|
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
|
||||||
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||||
|
|
||||||
print(f"Scale difference: {scale_diff}")
|
|
||||||
print(f"Output difference: {output_diff}")
|
|
||||||
|
|
||||||
if torch.allclose(
|
if torch.allclose(
|
||||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||||
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
|
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
|
||||||
|
|||||||
@@ -7,38 +7,6 @@
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
|
||||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
|
||||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
|
||||||
#else
|
|
||||||
#include <c10/util/Float8_e4m3fnuz.h>
|
|
||||||
|
|
||||||
#include "amd/quant_utils.cuh"
|
|
||||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
|
||||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
|
||||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
|
||||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
|
||||||
float old;
|
|
||||||
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
|
||||||
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
|
||||||
return old;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
|
||||||
return max_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
|
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
|
||||||
const int64_t num_elements) {
|
const int64_t num_elements) {
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cub/block/block_reduce.cuh>
|
#include <cub/block/block_reduce.cuh>
|
||||||
@@ -7,31 +6,6 @@
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
|
||||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
|
||||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
|
||||||
#else
|
|
||||||
#include <c10/util/Float8_e4m3fnuz.h>
|
|
||||||
|
|
||||||
#include "amd/quant_utils.cuh"
|
|
||||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
|
||||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
|
||||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
|
||||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
|
|
||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
|
||||||
return max_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
|
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
|
||||||
float* __restrict__ output_s, const int64_t hidden_dim,
|
float* __restrict__ output_s, const int64_t hidden_dim,
|
||||||
|
|||||||
@@ -95,3 +95,33 @@ inline int getSMVersion() {
|
|||||||
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
||||||
|
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <c10/util/Float8_e4m3fn.h>
|
||||||
|
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||||
|
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||||
|
#else
|
||||||
|
#include <c10/util/Float8_e4m3fnuz.h>
|
||||||
|
|
||||||
|
#include "amd/quant_utils.cuh"
|
||||||
|
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||||
|
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||||
|
float old;
|
||||||
|
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||||
|
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||||
|
return old;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float warpReduceMax(float max_value) {
|
||||||
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
|
||||||
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
|
||||||
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
|
||||||
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
|
||||||
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
||||||
|
return max_value;
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user