diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py index 31b48021..8110b116 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py @@ -127,22 +127,22 @@ def create_test_data( repetiton_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32) for i in range(num_reqs): if torch.rand(1) > 0.3: - repetiton_penalty[i] = torch.rand(1, device).item() * 0.8 + 0.6 + repetiton_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6 frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) for i in range(num_reqs): if torch.rand(1) > 0.5: - frequency_penalty[i] = torch.rand(1, device).item() * 0.2 + frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2 presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) for i in range(num_reqs): if torch.rand(1) > 0.5: - presence_penalty[i] = torch.rand(1, device).item() * 0.2 + presence_penalty[i] = torch.rand(1, device=device).item() * 0.2 temperature = torch.ones(num_reqs, device=device, dtype=torch.float32) for i in range(num_reqs): if torch.rand(1) > 0.2: - presence_penalty[i] = torch.rand(1, device).item() * 1.8 + 0.2 + presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2 idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32)