[sgl-kernel] per token group quant support COLUMN MAJOR (#4817)

This commit is contained in:
Xiaoyu Zhang
2025-04-03 09:29:59 +08:00
committed by GitHub
parent 31da75abed
commit 2c8fd99363
3 changed files with 252 additions and 80 deletions

View File

@@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(
def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
device = torch.device("cuda")
-hidden_dim = group_size * 2
+hidden_dim = 7168
-x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+x = torch.randn(
+batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+)
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
x.clone(), group_size, dst_dtype
@@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
device = torch.device("cuda")
hidden_dim = 7168
-x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+x = torch.randn(
+batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+)
quantiles = [0.5, 0.2, 0.8]