Sampling penalties memory interface (#2870)
This commit is contained in:
@@ -222,8 +222,9 @@ configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||
def benchmark(batch_size, seq_len, provider):
|
||||
num_experts = 256
|
||||
block_size = 128
|
||||
topk = 8
|
||||
topk_ids = torch.randint(
|
||||
0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
|
||||
0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
Reference in New Issue
Block a user