Sampling penalties memory interface (#2870)

This commit is contained in:
Xiaoyu Zhang
2025-01-13 23:09:00 +08:00
committed by GitHub
parent c1e097ca66
commit d08c77c434
7 changed files with 251 additions and 41 deletions

View File

@@ -222,8 +222,9 @@ configs = list(itertools.product(batch_size_range, seq_length_range))
def benchmark(batch_size, seq_len, provider):
num_experts = 256
block_size = 128
+topk = 8
topk_ids = torch.randint(
-0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
+0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)