This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -109,8 +108,7 @@ def bench_kineto(
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([int(re.search(name, line) is not None) for line in prof_lines])
|
||||
== 1
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
|
||||
|
||||
# Save chrome traces
|
||||
@@ -124,7 +122,7 @@ def bench_kineto(
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if re.search(name, line) is not None:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
|
||||
@@ -43,17 +43,11 @@ _is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = True
|
||||
except ImportError:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
|
||||
enable_sgl_per_token_group_quant_8bit = False
|
||||
from sgl_kernel import (
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
if _use_aiter:
|
||||
@@ -502,24 +496,9 @@ def sglang_per_token_group_quant_fp8(
|
||||
)
|
||||
|
||||
if x.shape[0] > 0:
|
||||
# Temporary
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
sgl_per_token_group_quant_8bit(
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
scale_ue8m0,
|
||||
fuse_silu_and_mul,
|
||||
masked_m,
|
||||
)
|
||||
else:
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
sgl_per_token_group_quant_fp8(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -12,13 +12,7 @@ from sglang.srt.utils import get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
# Temporary
|
||||
try:
|
||||
from sgl_kernel import sgl_per_token_group_quant_8bit
|
||||
except ImportError:
|
||||
from sgl_kernel import (
|
||||
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
|
||||
)
|
||||
from sgl_kernel import sgl_per_token_group_quant_int8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -210,7 +204,7 @@ def sglang_per_token_group_quant_int8(
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel.test_utils import create_per_token_group_quant_test_data
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
@@ -21,231 +19,78 @@ from sglang.srt.utils import is_hip
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
|
||||
|
||||
if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
|
||||
# configs = [[
|
||||
# 768,
|
||||
# 16384,
|
||||
# 128,
|
||||
# None,
|
||||
# fp8_type_,
|
||||
# dict(
|
||||
# column_major_scales=True,
|
||||
# scale_tma_aligned=True,
|
||||
# scale_ue8m0=True,
|
||||
# fuse_silu_and_mul=False,
|
||||
# masked_layout_mode=None,
|
||||
# ),
|
||||
# ]]
|
||||
configs = [
|
||||
[
|
||||
768 * 8,
|
||||
2048,
|
||||
128,
|
||||
48,
|
||||
fp8_type_,
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
# masked_layout_mode=None,
|
||||
masked_layout_mode="balanced",
|
||||
# masked_layout_mode="extreme",
|
||||
),
|
||||
]
|
||||
]
|
||||
elif mode_concentrated:
|
||||
configs = list(
|
||||
itertools.product(
|
||||
[768],
|
||||
[1536, 7168, 16384],
|
||||
[128],
|
||||
[None],
|
||||
[fp8_type_],
|
||||
[
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
) + list(
|
||||
itertools.product(
|
||||
[768 * 8],
|
||||
[2048],
|
||||
[128],
|
||||
[48],
|
||||
[fp8_type_],
|
||||
[
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="balanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="imbalanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="extreme",
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
configs = list(
|
||||
itertools.product(
|
||||
[1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
|
||||
[1536, 7168, 16384],
|
||||
[128],
|
||||
[None],
|
||||
[fp8_type_],
|
||||
[
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
) + list(
|
||||
itertools.product(
|
||||
[1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
|
||||
[2048],
|
||||
[128],
|
||||
[8, 16, 32, 48],
|
||||
[fp8_type_],
|
||||
[
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="balanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="imbalanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="extreme",
|
||||
),
|
||||
],
|
||||
)
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
flags_range = [
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
configs = list(
|
||||
itertools.product(
|
||||
num_tokens_range,
|
||||
hidden_dim_range,
|
||||
group_size_range,
|
||||
dst_dtype_range,
|
||||
flags_range,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=[
|
||||
"num_tokens",
|
||||
"hidden_dim",
|
||||
"group_size",
|
||||
"num_ranks",
|
||||
"dst_dtype",
|
||||
"flags",
|
||||
],
|
||||
x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["triton", "sglang"],
|
||||
# Triton has multi kernels and we only report the time for the core one
|
||||
line_names=["Triton (Inaccurate)", "SGL Kernel"],
|
||||
line_names=["Triton", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-group-quant-8bit-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(
|
||||
num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
|
||||
):
|
||||
print(
|
||||
f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
|
||||
)
|
||||
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
|
||||
if flags["scale_ue8m0"] and group_size != 128:
|
||||
return
|
||||
|
||||
x, masked_m = create_per_token_group_quant_test_data(
|
||||
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
|
||||
)
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
|
||||
|
||||
fn, kernel_names = {
|
||||
"triton": (
|
||||
triton_per_token_group_quant_8bit,
|
||||
"_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
|
||||
),
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
|
||||
"sglang": (
|
||||
sglang_per_token_group_quant_8bit,
|
||||
"per_token_group_quant_8bit_kernel",
|
||||
),
|
||||
}[provider]
|
||||
bench_fn = lambda: fn(
|
||||
x=x,
|
||||
masked_m=masked_m,
|
||||
group_size=group_size,
|
||||
dst_dtype=dst_dtype,
|
||||
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
|
||||
)
|
||||
bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
|
||||
|
||||
time_s = bench_kineto(
|
||||
bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
|
||||
)
|
||||
time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
|
||||
return time_s * 1e6
|
||||
|
||||
|
||||
|
||||
@@ -121,9 +121,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float int8_min, float int8_max) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
|
||||
|
||||
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
|
||||
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
|
||||
|
||||
@@ -1,396 +1,119 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
template <int THREADS_PER_SUBWARP>
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||
|
||||
static_assert(
|
||||
(THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
|
||||
"THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
|
||||
|
||||
if constexpr (THREADS_PER_SUBWARP >= 16) {
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
}
|
||||
if constexpr (THREADS_PER_SUBWARP >= 8) {
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
}
|
||||
if constexpr (THREADS_PER_SUBWARP >= 4) {
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
}
|
||||
if constexpr (THREADS_PER_SUBWARP >= 2) {
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
}
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
return val;
|
||||
}
|
||||
|
||||
// SiLU activation, val * sigmoid(val), evaluated in single precision.
//
// When compiled for __CUDA_ARCH__ >= 1000 the value is computed through the
// identity  x * sigmoid(x) == (x / 2) * (1 + tanh(x / 2))  with the __tanhf
// intrinsic; every other target uses the direct form with the __expf intrinsic.
__device__ __forceinline__ float silu(const float& val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const float half_val = 0.5f * val;
  return half_val * (1.0f + __tanhf(half_val));
#else
  const float denom = 1.0f + __expf(-val);
  return val / denom;
#endif
}
|
||||
|
||||
// Element-wise product of two float2 values: {a.x * b.x, a.y * b.y}.
//
// When compiled for __CUDA_ARCH__ >= 1000 this forwards to the packed
// __fmul2_rn intrinsic; otherwise the two lanes are multiplied separately.
__device__ float2 fmul2_rn(float2 a, float2 b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  return __fmul2_rn(a, b);
#else
  float2 product;
  product.x = a.x * b.x;
  product.y = a.y * b.y;
  return product;
#endif
}
|
||||
|
||||
// Copied and modified from DeepEP
|
||||
// Returns 2^x as a float by placing the biased exponent (x + 127) directly into
// the exponent field (bits 23..30) of a binary32 pattern whose sign and
// mantissa bits are zero.
//
// Precondition (not checked here): callers ensure `-126 <= x and x <= 127`.
//
// The bit pattern is built with an unsigned shift and converted with the
// __uint_as_float intrinsic, so the result does not depend on pointer-based
// type punning or on shifting a signed value.
__forceinline__ __device__ float fast_pow2(int x) {
  const uint32_t bits_x = static_cast<uint32_t>(x + 127) << 23;
  return __uint_as_float(bits_x);
}
|
||||
|
||||
// Copied and modified from DeepEP
|
||||
// Returns the unbiased binary exponent of x, plus one when any mantissa bit is
// set. For a positive normal float this equals ceil(log2(x)).
//
// The sign bit is masked away and zero / subnormal / inf / NaN inputs are not
// special-cased: they simply yield whatever their raw exponent field implies.
//
// The bits are read with the __float_as_uint intrinsic (instead of casting the
// address of x), and the exponent is widened to a signed int before the bias
// is subtracted so the arithmetic never relies on unsigned wrap-around.
__forceinline__ __device__ int fast_log2_ceil(float x) {
  const uint32_t bits_x = __float_as_uint(x);
  const int exp_x = static_cast<int>((bits_x >> 23) & 0xffu);
  const uint32_t man_bits = bits_x & ((1u << 23) - 1u);
  return exp_x - 127 + (man_bits != 0u ? 1 : 0);
}
|
||||
|
||||
// Copied and modified from DeepEP
|
||||
template <bool ROUND_SCALE, typename dtype_info>
|
||||
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
|
||||
constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
|
||||
if constexpr (ROUND_SCALE) {
|
||||
auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
|
||||
scale = fast_pow2(-exp_scale_inv);
|
||||
scale_inv = fast_pow2(exp_scale_inv);
|
||||
} else {
|
||||
scale_inv = amax * MAX_8BIT_INV;
|
||||
scale = dtype_info::MAX / amax;
|
||||
}
|
||||
}
|
||||
|
||||
// Copied and modified from DeepEP
|
||||
template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
|
||||
__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the four 32-bit lanes of `value` to global memory at `ptr` with a
// single `st.global.v4.s32` instruction.
//
// NOTE: `ptr` is declared pointer-to-const for interface compatibility, but the
// pointed-to memory IS written.
//
// The "memory" clobber tells the compiler that this statement modifies memory
// that is not listed as an output operand (the store goes through the address
// operand), so surrounding loads/stores are not cached or reordered across it.
__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
  asm volatile("st.global.v4.s32 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)
               : "memory");
}
|
||||
|
||||
// Loads four 32-bit lanes from global memory at `ptr` with a single
// `ld.global.nc.v4.s32` instruction and returns them as an int4.
//
// NOTE(review): the `.nc` qualifier is meant for data that is not written
// while the kernel runs — confirm against the PTX ISA and the call sites.
__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
  int4 loaded;
  asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
               : "=r"(loaded.x), "=r"(loaded.y), "=r"(loaded.z), "=r"(loaded.w)
               : "l"(ptr));
  return loaded;
}
|
||||
|
||||
template <typename T>
|
||||
struct DtypeInfo;
|
||||
|
||||
template <>
|
||||
struct DtypeInfo<int8_t> {
|
||||
static constexpr float MIN = -128;
|
||||
static constexpr float MAX = 127;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DtypeInfo<c10::Float8_e4m3fn> {
|
||||
static constexpr float MIN = -448;
|
||||
static constexpr float MAX = 448;
|
||||
};
|
||||
|
||||
template <bool FUSE_SILU_AND_MUL>
|
||||
__device__ __forceinline__ int compute_input_group_start_offset(
|
||||
int expert_idx,
|
||||
int token_idx,
|
||||
int hidden_dim_group_idx,
|
||||
int hidden_size,
|
||||
int num_tokens_per_expert,
|
||||
int group_size) {
|
||||
return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
|
||||
token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
|
||||
}
|
||||
|
||||
// Initial value (and therefore lower bound) of the per-group absolute maximum.
// The visible host wrapper checks that the caller-supplied `eps` matches it.
constexpr float LOCAL_ABSMAX_ABS = 1e-10;
|
||||
// Number of input bytes each lane loads for its group; the kernel divides this
// by sizeof(T) and sizeof(int4) to size its per-lane buffers.
constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;
|
||||
|
||||
struct NaiveScheduler {
|
||||
static void compute_exec_config(
|
||||
int threads_per_subwarp,
|
||||
int num_local_experts,
|
||||
int hidden_dim_num_groups,
|
||||
int num_groups,
|
||||
int& subwarps_per_block,
|
||||
dim3& grid,
|
||||
dim3& block) {
|
||||
subwarps_per_block = ([=]() -> int {
|
||||
if (num_groups % 16 == 0) {
|
||||
return 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
return 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
return 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
})();
|
||||
grid = dim3(num_groups / subwarps_per_block);
|
||||
block = dim3(subwarps_per_block * threads_per_subwarp);
|
||||
}
|
||||
|
||||
template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
|
||||
__device__ __forceinline__ static void execute(
|
||||
const int subwarps_per_block,
|
||||
const int hidden_dim_num_groups,
|
||||
const int32_t* masked_m,
|
||||
const int num_tokens_per_expert,
|
||||
FUNC fn) {
|
||||
constexpr int expert_idx = 0;
|
||||
|
||||
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
|
||||
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * subwarps_per_block;
|
||||
const int64_t group_id = block_group_id + subwarp_id;
|
||||
|
||||
int64_t input_group_start_offset;
|
||||
if constexpr (!FUSE_SILU_AND_MUL) {
|
||||
input_group_start_offset = group_id * GROUP_SIZE;
|
||||
}
|
||||
|
||||
const int token_idx = group_id / hidden_dim_num_groups;
|
||||
// At the hidden_size dimension, we are handling idx-th group
|
||||
const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;
|
||||
|
||||
if constexpr (FUSE_SILU_AND_MUL) {
|
||||
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
|
||||
input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
|
||||
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
|
||||
}
|
||||
|
||||
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
|
||||
}
|
||||
};
|
||||
|
||||
struct MaskedLayoutScheduler {
|
||||
// TODO can be dynamically determined (which may be good when num rank is small)
|
||||
static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
|
||||
static constexpr int SUBWARPS_PER_BLOCK = 16;
|
||||
|
||||
static void compute_exec_config(
|
||||
int threads_per_subwarp,
|
||||
int num_local_experts,
|
||||
int hidden_dim_num_groups,
|
||||
int num_groups,
|
||||
int& subwarps_per_block,
|
||||
dim3& grid,
|
||||
dim3& block) {
|
||||
subwarps_per_block = SUBWARPS_PER_BLOCK;
|
||||
TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
|
||||
grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
|
||||
block = dim3(subwarps_per_block * threads_per_subwarp);
|
||||
}
|
||||
|
||||
template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
|
||||
__device__ __forceinline__ static void execute(
|
||||
const int subwarps_per_block,
|
||||
const int hidden_dim_num_groups,
|
||||
const int32_t* masked_m,
|
||||
const int num_tokens_per_expert,
|
||||
FUNC fn) {
|
||||
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
|
||||
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
|
||||
|
||||
const int expert_idx = blockIdx.z;
|
||||
const int token_idx_start = blockIdx.y;
|
||||
|
||||
const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id;
|
||||
|
||||
const int curr_expert_token_num = masked_m[expert_idx];
|
||||
|
||||
for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
|
||||
token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
|
||||
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
|
||||
const int64_t input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
|
||||
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
|
||||
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename SCHEDULER,
|
||||
int GROUP_SIZE,
|
||||
int THREADS_PER_SUBWARP,
|
||||
typename T,
|
||||
typename DST_DTYPE,
|
||||
bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false,
|
||||
bool FUSE_SILU_AND_MUL = false,
|
||||
typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
void* __restrict__ output_q,
|
||||
scale_packed_t* __restrict__ output_s,
|
||||
const int32_t* __restrict__ masked_m,
|
||||
const int subwarps_per_block,
|
||||
const int hidden_dim_num_groups,
|
||||
// TODO can this be removed?
|
||||
const int scale_expert_stride,
|
||||
const int scale_hidden_stride,
|
||||
const int num_tokens_per_expert) {
|
||||
using dst_dtype_info = DtypeInfo<DST_DTYPE>;
|
||||
const int group_size,
|
||||
const int num_groups,
|
||||
const int groups_per_block,
|
||||
const float eps,
|
||||
const float min_8bit,
|
||||
const float max_8bit,
|
||||
const int num_groups_per_row = 0,
|
||||
const int scale_stride = 0) {
|
||||
const int threads_per_group = 16;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>(
|
||||
subwarps_per_block,
|
||||
hidden_dim_num_groups,
|
||||
masked_m,
|
||||
num_tokens_per_expert,
|
||||
[&](const int expert_idx,
|
||||
const int token_idx,
|
||||
const int hidden_dim_group_idx,
|
||||
const int lane_id,
|
||||
const int input_group_start_offset) {
|
||||
constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
|
||||
constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
scale_element_t* scale_output;
|
||||
|
||||
const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
|
||||
token_idx * hidden_dim_num_groups + hidden_dim_group_idx;
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
const int row_idx = global_group_id / num_groups_per_row;
|
||||
const int col_idx_unpacked = global_group_id % num_groups_per_row;
|
||||
const int col_idx = col_idx_unpacked / num_elems_per_pack;
|
||||
const int pack_idx = col_idx_unpacked % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
static_assert(!SCALE_UE8M0);
|
||||
scale_output = output_s + global_group_id;
|
||||
}
|
||||
|
||||
int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
|
||||
T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
|
||||
static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
|
||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||
|
||||
int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE];
|
||||
T* input_secondary_vec = reinterpret_cast<T*>(input_secondary_int4);
|
||||
static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4));
|
||||
const int32_t num_vec_elems = group_size / vec_size;
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(group_input + i * vec_size);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
|
||||
input_primary_int4[j] = ld_global_nc(
|
||||
reinterpret_cast<const int4*>(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
|
||||
}
|
||||
if constexpr (FUSE_SILU_AND_MUL) {
|
||||
const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
|
||||
input_secondary_int4[j] = ld_global_nc(
|
||||
reinterpret_cast<const int4*>(
|
||||
input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) +
|
||||
j);
|
||||
}
|
||||
}
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
float abs_val = fabsf(val);
|
||||
local_absmax = fmaxf(local_absmax, abs_val);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
scale_element_t* scale_output;
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
constexpr int scale_token_stride = 1;
|
||||
local_absmax = GroupReduceMax(local_absmax, lane_id);
|
||||
|
||||
const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack;
|
||||
const int pack_idx = hidden_dim_group_idx % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(expert_idx * scale_expert_stride * num_elems_per_pack +
|
||||
hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
|
||||
token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
static_assert(!SCALE_UE8M0);
|
||||
scale_output = output_s + offset_num_groups;
|
||||
}
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
|
||||
}
|
||||
|
||||
// can speed up if too slow
|
||||
if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) {
|
||||
const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack;
|
||||
if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and
|
||||
(lane_id < num_elems_per_pack - remainder_num_groups)) {
|
||||
const int shift = 1 + lane_id;
|
||||
*(scale_output + shift) = 0;
|
||||
}
|
||||
}
|
||||
// TODO can optimize
|
||||
scale_element_t y_s_quant;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
|
||||
} else {
|
||||
y_s_quant = y_s;
|
||||
}
|
||||
|
||||
float local_absmax = LOCAL_ABSMAX_ABS;
|
||||
if (lane_id == 0) {
|
||||
*scale_output = y_s_quant;
|
||||
}
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
|
||||
vec_t input_vec;
|
||||
input_vec.cast_load(group_input + i * vec_size);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
|
||||
float val;
|
||||
if constexpr (FUSE_SILU_AND_MUL) {
|
||||
// TODO maybe vectorize
|
||||
T val_lowprec = static_cast<T>(silu(static_cast<float>(input_primary_vec[j]))) * input_secondary_vec[j];
|
||||
val = static_cast<float>(val_lowprec);
|
||||
input_primary_vec[j] = val_lowprec;
|
||||
} else {
|
||||
val = static_cast<float>(input_primary_vec[j]);
|
||||
}
|
||||
|
||||
float abs_val = fabsf(val);
|
||||
local_absmax = fmaxf(local_absmax, abs_val);
|
||||
}
|
||||
|
||||
local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
|
||||
|
||||
float y_scale, y_scale_inv;
|
||||
calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
|
||||
float2 y_scale_repeated = {y_scale, y_scale};
|
||||
|
||||
if (lane_id == 0) {
|
||||
*scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
|
||||
}
|
||||
|
||||
int4 output_buf;
|
||||
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
|
||||
|
||||
if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
|
||||
const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
|
||||
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
|
||||
static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
|
||||
float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
|
||||
float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated);
|
||||
output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
} else {
|
||||
const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
|
||||
float val = static_cast<float>(input_primary_vec[j]);
|
||||
float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
|
||||
output_buf_ptr[j] = DST_DTYPE(q_val);
|
||||
}
|
||||
}
|
||||
|
||||
st_global(
|
||||
reinterpret_cast<int4*>(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE),
|
||||
output_buf);
|
||||
});
|
||||
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
|
||||
group_output[i * vec_size + j] = DST_DTYPE(q_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_per_token_group_quant_8bit(
|
||||
// vanilla: (num_tokens, hidden_size)
|
||||
// fuse_silu_and_mul: (num_tokens, hidden_size * 2)
|
||||
// fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
|
||||
torch::Tensor input,
|
||||
torch::Tensor output_q,
|
||||
torch::Tensor output_s,
|
||||
@@ -398,113 +121,120 @@ void sgl_per_token_group_quant_8bit(
|
||||
double eps,
|
||||
double min_8bit,
|
||||
double max_8bit,
|
||||
bool scale_ue8m0,
|
||||
bool fuse_silu_and_mul,
|
||||
const std::optional<torch::Tensor>& masked_m) {
|
||||
bool scale_ue8m0 = false) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(output_q);
|
||||
TORCH_CHECK(input.numel() > 0);
|
||||
|
||||
TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13);
|
||||
const int num_groups = input.numel() / group_size;
|
||||
|
||||
CHECK_EQ(input.numel() % group_size, 0);
|
||||
const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1);
|
||||
|
||||
const bool masked_layout = masked_m.has_value();
|
||||
TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));
|
||||
|
||||
const int num_local_experts = masked_layout ? input.size(0) : 1;
|
||||
CHECK_EQ(output_s.dim(), 2);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
|
||||
const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
|
||||
const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size;
|
||||
const int num_tokens_per_expert = static_cast<int>(output_q.size(-2));
|
||||
const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
|
||||
const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));
|
||||
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
|
||||
const int hidden_dim = input.size(input.dim() - 1);
|
||||
const int num_groups_per_row = hidden_dim / group_size;
|
||||
const int scale_stride = output_s.stride(1);
|
||||
|
||||
#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \
|
||||
do { \
|
||||
int subwarps_per_block; \
|
||||
dim3 grid, block; \
|
||||
SCHEDULER::compute_exec_config( \
|
||||
THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
|
||||
\
|
||||
per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
static_cast<DST_DTYPE*>(output_q.data_ptr()), \
|
||||
static_cast<output_s_dtype*>(output_s.data_ptr()), \
|
||||
static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0), \
|
||||
subwarps_per_block, \
|
||||
hidden_dim_num_groups, \
|
||||
scale_expert_stride, \
|
||||
scale_hidden_stride, \
|
||||
num_tokens_per_expert); \
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<uint32_t*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit, \
|
||||
num_groups_per_row, \
|
||||
scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit, \
|
||||
num_groups_per_row, \
|
||||
scale_stride); \
|
||||
} \
|
||||
} else { \
|
||||
assert(!scale_ue8m0); \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), \
|
||||
output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), \
|
||||
group_size, \
|
||||
num_groups, \
|
||||
groups_per_block, \
|
||||
(float)eps, \
|
||||
(float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \
|
||||
do { \
|
||||
constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \
|
||||
TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \
|
||||
\
|
||||
using dst_dtype_info = DtypeInfo<DST_DTYPE>; \
|
||||
CHECK_EQ(dst_dtype_info::MIN, min_8bit); \
|
||||
CHECK_EQ(dst_dtype_info::MAX, max_8bit); \
|
||||
\
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
if (fuse_silu_and_mul) { \
|
||||
if (masked_layout) { \
|
||||
LAUNCH_KERNEL_INNER( \
|
||||
MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
|
||||
} else { \
|
||||
LAUNCH_KERNEL_INNER( \
|
||||
NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
|
||||
} \
|
||||
} else { \
|
||||
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \
|
||||
} \
|
||||
} else { \
|
||||
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \
|
||||
} \
|
||||
} else { \
|
||||
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LAUNCH_KERNEL_OUTER(...) \
|
||||
switch (group_size) { \
|
||||
case 16: \
|
||||
LAUNCH_KERNEL(16, __VA_ARGS__); \
|
||||
break; \
|
||||
case 32: \
|
||||
LAUNCH_KERNEL(32, __VA_ARGS__); \
|
||||
break; \
|
||||
case 64: \
|
||||
LAUNCH_KERNEL(64, __VA_ARGS__); \
|
||||
break; \
|
||||
case 128: \
|
||||
LAUNCH_KERNEL(128, __VA_ARGS__); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported group_size"); \
|
||||
} \
|
||||
while (0)
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||
if (dst_type == at::ScalarType::Char) {
|
||||
LAUNCH_KERNEL_OUTER(scalar_t, int8_t);
|
||||
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||
return true;
|
||||
} else if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn);
|
||||
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
#undef LAUNCH_KERNEL
|
||||
#undef LAUNCH_KERNEL_INNER
|
||||
}
|
||||
|
||||
// Int8 entry point for per-token-group quantization.
// Forwards its seven arguments, unchanged and in order, to
// sgl_per_token_group_quant_8bit, with int8_min/int8_max in the min/max slots.
// No scale_ue8m0 argument is passed here.
// NOTE(review): this presumably relies on the callee declaring a default for
// its trailing scale_ue8m0 parameter - confirm against the callee's signature.
void sgl_per_token_group_quant_int8(
|
||||
torch::Tensor input,
|
||||
torch::Tensor output_q,
|
||||
torch::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double int8_min,
|
||||
double int8_max) {
|
||||
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
|
||||
}
|
||||
|
||||
// FP8 entry point for per-token-group quantization.
// Forwards all eight arguments, unchanged and in order, to
// sgl_per_token_group_quant_8bit: fp8_min/fp8_max go in the min/max slots and
// scale_ue8m0 is passed through explicitly.
void sgl_per_token_group_quant_fp8(
|
||||
torch::Tensor input,
|
||||
torch::Tensor output_q,
|
||||
torch::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double fp8_min,
|
||||
double fp8_max,
|
||||
bool scale_ue8m0) {
|
||||
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
||||
|
||||
@@ -207,17 +207,23 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
const torch::Dtype& out_dtype);
|
||||
void scaled_fp4_quant(
|
||||
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
|
||||
void sgl_per_token_group_quant_8bit(
|
||||
void sgl_per_token_group_quant_fp8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
at::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double min_8bit,
|
||||
double max_8bit,
|
||||
bool scale_ue8m0,
|
||||
bool fuse_silu_and_mul,
|
||||
const std::optional<torch::Tensor>& masked_m);
|
||||
double fp8_min,
|
||||
double fp8_max,
|
||||
bool scale_ue8m0);
|
||||
void sgl_per_token_group_quant_int8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
at::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double int8_min,
|
||||
double int8_max);
|
||||
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
void bmm_fp8(
|
||||
|
||||
@@ -58,7 +58,8 @@ from sgl_kernel.gemm import (
|
||||
scaled_fp4_grouped_quant,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_8bit,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
shuffle_rows,
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
|
||||
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
|
||||
return output
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_8bit(
|
||||
def sgl_per_token_group_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
@@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit(
|
||||
eps: float,
|
||||
fp8_min: float,
|
||||
fp8_max: float,
|
||||
scale_ue8m0: bool = False,
|
||||
fuse_silu_and_mul: bool = False,
|
||||
masked_m: Optional[torch.Tensor] = None,
|
||||
scale_ue8m0: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
|
||||
input,
|
||||
output_q,
|
||||
output_s,
|
||||
group_size,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
scale_ue8m0,
|
||||
fuse_silu_and_mul,
|
||||
masked_m,
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_int8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    int8_min: float,
    int8_max: float,
) -> None:
    """Invoke the registered ``sgl_per_token_group_quant_int8`` kernel op.

    All arguments are passed positionally, unchanged and in order, to
    ``torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default``.
    Nothing is returned by this wrapper.
    """
    kernel_op = torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default
    kernel_op(input, output_q, output_s, group_size, eps, int8_min, int8_max)
|
||||
|
||||
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    """Build a random bfloat16 CUDA input (and optional mask) for quant tests.

    Both a CPU and a CUDA generator are seeded with
    ``num_tokens * 10000 + hidden_dim`` so repeated calls with the same
    arguments draw the same random numbers.

    Args:
        num_tokens: Row count of the plain layout; in the masked layout it is
            the total number of tokens split across the local experts.
        hidden_dim: Width of the last dimension; doubled when
            ``flags["fuse_silu_and_mul"]`` is truthy.
        num_ranks: Only used by the masked layout, where it must divide 288.
        flags: Mapping that provides ``"fuse_silu_and_mul"`` and
            ``"masked_layout_mode"`` (``None``, ``"balanced"``,
            ``"imbalanced"`` or ``"extreme"``).

    Returns:
        ``(x, masked_m)``. ``masked_m`` is ``None`` for the plain layout and a
        per-expert token-count tensor on the CUDA device otherwise.

    Raises:
        NotImplementedError: For an unrecognized ``masked_layout_mode``.
    """
    cuda = torch.device("cuda")

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    width = hidden_dim * 2 if flags["fuse_silu_and_mul"] else hidden_dim
    randn_kwargs = dict(device=cuda, dtype=torch.bfloat16, generator=gen_cuda)

    masked_layout_mode = flags["masked_layout_mode"]

    if masked_layout_mode is None:
        # Plain (num_tokens, width) layout with some entries scaled up by 10x.
        x = torch.randn(num_tokens, width, **randn_kwargs)
        # NOTE(review): randn(...) < 0.001 is true for roughly half of the
        # elements; if rare outliers were intended, torch.rand was presumably
        # meant - confirm before changing, since it alters the test data.
        boost_mask = torch.randn(x.shape, device=cuda, generator=gen_cuda) < 0.001
        x[boost_mask] *= 10
        return x, None

    num_max_dispatch_tokens_per_rank = 768
    num_global_experts = 288
    num_local_experts, leftover = divmod(num_global_experts, num_ranks)
    assert leftover == 0

    # mimic DeepEP low_latency_dispatch output
    x = torch.randn(
        num_local_experts,
        num_max_dispatch_tokens_per_rank * num_ranks,
        width,
        **randn_kwargs,
    )

    if masked_layout_mode == "balanced":
        masked_m = _compute_balanced_split(num_tokens, num_local_experts)
    elif masked_layout_mode == "imbalanced":
        masked_m = _compute_imbalanced_split(
            num_tokens, num_local_experts, gen_cpu=gen_cpu
        )
    elif masked_layout_mode == "extreme":
        # Every token goes to the first expert; the rest receive none.
        masked_m = torch.tensor(
            [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
        )
    else:
        raise NotImplementedError
    print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")

    return x, masked_m.to(cuda)
|
||||
|
||||
|
||||
def _compute_balanced_split(total: int, arr_len: int):
|
||||
base = total // arr_len
|
||||
remainder = total % arr_len
|
||||
ans = [base + 1 if i < remainder else base for i in range(arr_len)]
|
||||
assert sum(ans) == total
|
||||
return torch.tensor(ans, dtype=torch.int)
|
||||
|
||||
|
||||
def _compute_imbalanced_split(
|
||||
total: int, arr_len: int, gen_cpu, dtype=torch.int
|
||||
) -> list[int]:
|
||||
# can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
|
||||
noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
|
||||
|
||||
noise = noise_raw / noise_raw.sum()
|
||||
ans = (noise * total).round().to(dtype)
|
||||
|
||||
diff = total - ans.sum().item()
|
||||
while diff != 0:
|
||||
idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
|
||||
if diff > 0:
|
||||
ans[idx] += 1
|
||||
diff -= 1
|
||||
elif diff < 0 and ans[idx] > 0:
|
||||
ans[idx] -= 1
|
||||
diff += 1
|
||||
|
||||
assert sum(ans) == total
|
||||
return ans
|
||||
|
||||
|
||||
def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
|
||||
assert (a.shape == b.shape) and (
|
||||
a.dtype == b.dtype
|
||||
), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"
|
||||
numel = a.numel()
|
||||
|
||||
if a.dtype == torch.float8_e4m3fn:
|
||||
a_u8 = a.view(torch.uint8)
|
||||
b_u8 = b.view(torch.uint8)
|
||||
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
|
||||
|
||||
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
|
||||
count_tiny_diff = (diff_u8 == 1).sum().item()
|
||||
count_large_diff = (diff_u8 >= 2).sum().item()
|
||||
elif a.dtype == torch.int8:
|
||||
diff = (a.to(torch.int16) - a.to(torch.int16)).abs()
|
||||
count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
|
||||
count_tiny_diff = (diff == 1).sum().item()
|
||||
count_large_diff = (diff >= 2).sum().item()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
assert (
|
||||
(count_diff_sign == 0)
|
||||
and (count_large_diff == 0)
|
||||
and (
|
||||
(count_tiny_diff / numel < 0.005)
|
||||
or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
|
||||
)
|
||||
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
|
||||
@@ -1,199 +1,96 @@
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel.test_utils import (
|
||||
assert_all_close_or_tiny_diff,
|
||||
create_per_token_group_quant_test_data,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
configs = list(
|
||||
itertools.product(
|
||||
[1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens
|
||||
[128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim
|
||||
[16, 32, 64, 128], # group_size
|
||||
[None], # num_ranks
|
||||
[fp8_type_, torch.int8], # dtype
|
||||
[
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=False,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=False,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
) + list(
|
||||
itertools.product(
|
||||
[1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
|
||||
# TODO support more
|
||||
[2048],
|
||||
[128],
|
||||
[8, 16, 32, 48],
|
||||
[fp8_type_],
|
||||
[
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode=None,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="balanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="imbalanced",
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
fuse_silu_and_mul=True,
|
||||
masked_layout_mode="extreme",
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
|
||||
"num_tokens, hidden_dim, group_size, dst_dtype, flags",
|
||||
list(
|
||||
itertools.product(
|
||||
[127, 128, 512, 1024, 4096, 8192], # num_tokens
|
||||
[256, 512, 1024, 2048, 4096], # hidden_dim
|
||||
[8, 16, 32, 64, 128], # group_size
|
||||
# TODO test int8
|
||||
[fp8_type_], # dtype
|
||||
[
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_per_token_group_quant_with_column_major(
|
||||
num_tokens,
|
||||
hidden_dim,
|
||||
group_size,
|
||||
num_ranks,
|
||||
dst_dtype,
|
||||
flags,
|
||||
):
|
||||
print(
|
||||
f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
|
||||
)
|
||||
|
||||
arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
|
||||
if flags["scale_ue8m0"] and (arch_major <= 9):
|
||||
pytest.skip("Only Blackwell need ue8m0 fusion")
|
||||
return
|
||||
|
||||
if (flags["scale_ue8m0"] and (group_size != 128)) or (
|
||||
(dst_dtype == torch.int8) and flags["column_major_scales"]
|
||||
):
|
||||
if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
|
||||
pytest.skip()
|
||||
return
|
||||
if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
|
||||
pytest.skip("scale_ue8m0 only supported on Blackwell")
|
||||
return
|
||||
|
||||
x, masked_m = create_per_token_group_quant_test_data(
|
||||
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
|
||||
)
|
||||
|
||||
# print("hack data!!!")
|
||||
# x = torch.full_like(x, fill_value=100)
|
||||
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
execute_kwargs = dict(
|
||||
x=x,
|
||||
masked_m=masked_m,
|
||||
group_size=group_size,
|
||||
eps=1e-10,
|
||||
dst_dtype=dst_dtype,
|
||||
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
|
||||
**flags,
|
||||
)
|
||||
|
||||
def _postprocess(x_q, x_s):
|
||||
if masked_m is not None:
|
||||
print(f"Mask tokens after {masked_m} to be zero")
|
||||
for i in range(len(masked_m)):
|
||||
x_q[i, masked_m[i] :, :] = 0
|
||||
x_s[i, masked_m[i] :, :] = 0
|
||||
return x_q, x_s
|
||||
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
|
||||
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
|
||||
|
||||
x_q_triton, x_s_triton = _postprocess(
|
||||
*triton_per_token_group_quant_8bit(**execute_kwargs)
|
||||
# torch.set_printoptions(profile="full")
|
||||
# print(f"{x_q_triton=}")
|
||||
# print(f"{x_s_triton=}")
|
||||
# print(f"{x_q_sglang=}")
|
||||
# print(f"{x_s_sglang=}")
|
||||
# torch.set_printoptions(profile="default")
|
||||
|
||||
assert_fp8_all_close(x_q_triton, x_q_sglang)
|
||||
torch.testing.assert_close(
|
||||
x_s_triton.contiguous(),
|
||||
x_s_sglang.contiguous(),
|
||||
rtol=1e-3,
|
||||
atol=1e-5,
|
||||
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
|
||||
)
|
||||
x_q_sglang, x_s_sglang = _postprocess(
|
||||
*sglang_per_token_group_quant_8bit(**execute_kwargs)
|
||||
)
|
||||
|
||||
try:
|
||||
assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
|
||||
torch.testing.assert_close(
|
||||
x_s_triton.contiguous(),
|
||||
x_s_sglang.contiguous(),
|
||||
rtol=1e-3,
|
||||
atol=1e-5,
|
||||
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
|
||||
)
|
||||
except AssertionError:
|
||||
# torch.set_printoptions(profile="full")
|
||||
print(
|
||||
f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
|
||||
)
|
||||
print(f"{x=}")
|
||||
print(f"{masked_m=}")
|
||||
print(f"{x_q_triton=}")
|
||||
print(f"{x_s_triton=}")
|
||||
print(f"{x_q_sglang=}")
|
||||
print(f"{x_s_sglang=}")
|
||||
# torch.set_printoptions(profile="default")
|
||||
|
||||
# if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
|
||||
# import matplotlib.pyplot as plt
|
||||
#
|
||||
# base_stem = time.time()
|
||||
# for name, value in [
|
||||
# ("x_q", x_q_triton != x_q_sglang),
|
||||
# ("x_s", x_s_triton != x_s_sglang),
|
||||
# ]:
|
||||
# value = value.reshape((-1, value.shape[-1]))
|
||||
# plt.figure(figsize=(20, 20))
|
||||
# plt.imshow((value * 1.0).cpu().numpy())
|
||||
# p = Path(d) / f"{base_stem}_{name}.png"
|
||||
# print(f"Write diff to {p}", flush=True)
|
||||
# plt.savefig(p)
|
||||
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user