Use warp-shuffle-style reduction and FlashInfer-style vectorization (#3628)

This commit is contained in:
Xiaoyu Zhang
2025-02-19 20:53:51 +08:00
committed by GitHub
parent fe0673f1cc
commit 55a7ec388f
2 changed files with 48 additions and 42 deletions

View File

@@ -186,7 +186,7 @@ configs = list(itertools.product(batch_size_range, seq_len_range, group_size_ran
def benchmark(batch_size, seq_len, group_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
hidden_dim = group_size * 2
hidden_dim = 7168
x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)