diff --git a/tests/e2e/nightly/ops/triton/test_rejection_sampler.py b/tests/e2e/nightly/ops/triton/test_rejection_sampler.py index 86992711..3820fd11 100644 --- a/tests/e2e/nightly/ops/triton/test_rejection_sampler.py +++ b/tests/e2e/nightly/ops/triton/test_rejection_sampler.py @@ -61,7 +61,7 @@ IS_GREEDY = torch.zeros(NUM_TOKENS, dtype=torch.bool, device=DEVICE) @pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS]) @pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS]) @pytest.mark.parametrize("is_greedy", [IS_GREEDY]) -@pytest.mark.parametrize("vocab_size", [BATCH_SIZE]) +@pytest.mark.parametrize("batch_size", [BATCH_SIZE]) @pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN]) @pytest.mark.parametrize("vocab_size", [VOCAB_SIZE]) @torch.inference_mode()