[Test][BugFix] Fix torch.rand usage in triton penalty test (#6680)
### What this PR does / why we need it?
This PR fixes a `TypeError` in
`tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py`
that was causing nightly test failures. The `torch.rand()` function was
being called with the `device` string as a positional argument, which is
incorrect. This has been corrected to use the `device` keyword argument.
Fixes #
### Does this PR introduce _any_ user-facing change?
No, this change only affects a test file.
### How was this patch tested?
CI is expected to pass with this fix.
- vLLM version: v0.15.0
- vLLM main:
13397841ab
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -127,22 +127,22 @@ def create_test_data(
|
|||||||
repetiton_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
repetiton_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
if torch.rand(1) > 0.3:
|
if torch.rand(1) > 0.3:
|
||||||
repetiton_penalty[i] = torch.rand(1, device).item() * 0.8 + 0.6
|
repetiton_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
|
||||||
|
|
||||||
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
if torch.rand(1) > 0.5:
|
if torch.rand(1) > 0.5:
|
||||||
frequency_penalty[i] = torch.rand(1, device).item() * 0.2
|
frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2
|
||||||
|
|
||||||
presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
if torch.rand(1) > 0.5:
|
if torch.rand(1) > 0.5:
|
||||||
presence_penalty[i] = torch.rand(1, device).item() * 0.2
|
presence_penalty[i] = torch.rand(1, device=device).item() * 0.2
|
||||||
|
|
||||||
temperature = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
temperature = torch.ones(num_reqs, device=device, dtype=torch.float32)
|
||||||
for i in range(num_reqs):
|
for i in range(num_reqs):
|
||||||
if torch.rand(1) > 0.2:
|
if torch.rand(1) > 0.2:
|
||||||
presence_penalty[i] = torch.rand(1, device).item() * 1.8 + 0.2
|
presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2
|
||||||
|
|
||||||
idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32)
|
idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user