use warp shuffle style reduce and flashinfer vectorize (#3628)
This commit is contained in:
@@ -186,7 +186,7 @@ configs = list(itertools.product(batch_size_range, seq_len_range, group_size_ran
|
||||
def benchmark(batch_size, seq_len, group_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
hidden_dim = group_size * 2
|
||||
hidden_dim = 7168
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user