[sgl-kernel] per token group quant support COLUMN MAJOR (#4817)
This commit is contained in:
@@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(
|
||||
|
||||
def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
|
||||
device = torch.device("cuda")
|
||||
hidden_dim = group_size * 2
|
||||
hidden_dim = 7168
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
|
||||
x = torch.randn(
|
||||
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
|
||||
)
|
||||
|
||||
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
|
||||
x.clone(), group_size, dst_dtype
|
||||
@@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
|
||||
device = torch.device("cuda")
|
||||
hidden_dim = 7168
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
|
||||
x = torch.randn(
|
||||
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user