Sync from v0.13
This commit is contained in:
@@ -1,124 +1,49 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
|
||||
SequenceGroup, SequenceGroupOutput, SequenceOutput)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
def create_dummy_prompt(
|
||||
request_id: str,
|
||||
prompt_length: int,
|
||||
block_size: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
use_beam_search: bool = False,
|
||||
best_of: int = 1,
|
||||
) -> SequenceGroup:
|
||||
if not block_size:
|
||||
block_size = prompt_length
|
||||
def test_sequence_intermediate_tensors_equal():
|
||||
class AnotherIntermediateTensors(IntermediateTensors):
|
||||
pass
|
||||
|
||||
# Create dummy prompt sequence with tokens 0...block_size-1
|
||||
# and prompt "0 ... block_size".
|
||||
prompt_tokens = list(range(prompt_length))
|
||||
prompt_str = " ".join([str(t) for t in prompt_tokens])
|
||||
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
|
||||
seq_group = SequenceGroup(
|
||||
request_id, [prompt],
|
||||
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
|
||||
time.time(), lora_request)
|
||||
intermediate_tensors = IntermediateTensors({})
|
||||
another_intermediate_tensors = AnotherIntermediateTensors({})
|
||||
assert intermediate_tensors != another_intermediate_tensors
|
||||
|
||||
return seq_group
|
||||
empty_intermediate_tensors_1 = IntermediateTensors({})
|
||||
empty_intermediate_tensors_2 = IntermediateTensors({})
|
||||
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
|
||||
|
||||
different_key_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)}
|
||||
)
|
||||
difference_key_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"2": torch.zeros([2, 4], dtype=torch.int32)}
|
||||
)
|
||||
assert different_key_intermediate_tensors_1 != difference_key_intermediate_tensors_2
|
||||
|
||||
@pytest.fixture
def sample_outputs():
    """Five SequenceGroupOutputs, each wrapping a single sample.

    Sample i carries output_token == i, parent_seq_id 0, and empty logprobs.
    """
    outputs = []
    for token in range(5):
        sample = SequenceOutput(parent_seq_id=0,
                                output_token=token,
                                logprobs={})
        outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))
    return outputs
|
||||
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)}
|
||||
)
|
||||
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 5], dtype=torch.int32)}
|
||||
)
|
||||
assert (
|
||||
same_key_different_value_intermediate_tensors_1
|
||||
!= same_key_different_value_intermediate_tensors_2
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def sampler_output(sample_outputs):
    """A SamplerOutput wrapping the sample_outputs fixture."""
    wrapped = SamplerOutput(outputs=sample_outputs)
    return wrapped
|
||||
|
||||
|
||||
def test_sampler_output_initialization(sampler_output, sample_outputs):
    """A freshly built SamplerOutput has the expected length and leaves
    every optional field unset."""
    assert len(sample_outputs) == len(sampler_output)
    # All optional attributes default to None until populated elsewhere.
    for optional_value in (sampler_output.sampled_token_probs,
                           sampler_output.sampled_token_ids,
                           sampler_output.spec_decode_worker_metrics):
        assert optional_value is None
|
||||
|
||||
|
||||
def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing a SamplerOutput returns the corresponding wrapped output."""
    expected = sample_outputs[2]
    assert sampler_output[2] == expected
|
||||
|
||||
|
||||
def test_sampler_output_setitem(sampler_output):
    """Assigning by index replaces the stored SequenceGroupOutput."""
    replacement_sample = SequenceOutput(parent_seq_id=0,
                                        output_token=99,
                                        logprobs={})
    replacement = SequenceGroupOutput(samples=[replacement_sample],
                                      prompt_logprobs=None)
    sampler_output[2] = replacement
    assert sampler_output[2] == replacement
|
||||
|
||||
|
||||
def test_sampler_output_len(sampler_output, sample_outputs):
    """len() of a SamplerOutput matches the number of wrapped outputs."""
    expected_len = len(sample_outputs)
    assert len(sampler_output) == expected_len
|
||||
|
||||
|
||||
def test_sampler_output_eq(sample_outputs):
    """SamplerOutput equality compares the wrapped outputs element-wise."""
    full = SamplerOutput(outputs=sample_outputs)
    # A shallow copy of the outputs list compares equal...
    duplicate = SamplerOutput(outputs=list(sample_outputs))
    # ...while dropping the last element does not.
    truncated = SamplerOutput(outputs=sample_outputs[:-1])
    assert full == duplicate
    assert full != truncated
|
||||
|
||||
|
||||
def test_sequence_data_prefill():
    """Computed/uncomputed token accounting across prefill and recompute.

    Starts from a 4-token prompt, advances the computed counter in steps,
    then appends a token and resets, which must restore the fully
    uncomputed state over all 5 tokens.
    """
    data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
    assert (data.get_num_uncomputed_tokens(),
            data.get_num_computed_tokens()) == (4, 0)

    # advance by 2
    data.update_num_computed_tokens(2)
    assert (data.get_num_uncomputed_tokens(),
            data.get_num_computed_tokens()) == (2, 2)

    # advance by 1
    data.update_num_computed_tokens(1)
    assert (data.get_num_uncomputed_tokens(),
            data.get_num_computed_tokens()) == (1, 3)

    # append a token and reset, simulating recompute
    data.append_token_id(1, logprob=0.0)
    data.reset_state_for_recompute()
    assert (data.get_num_uncomputed_tokens(),
            data.get_num_computed_tokens()) == (5, 0)
|
||||
|
||||
|
||||
def test_sequence_group_stage():
    """is_prefill() flips to False only once every token is computed.

    First drains a 12-token prompt, then appends a decode token and
    resets for recompute, which must re-enter prefill over 13 tokens.
    """
    group = create_dummy_prompt("1", 12)
    assert group.is_prefill() is True

    # Partial progress (6 then 11 of 12) keeps the group in prefill;
    # the final token flips the stage.
    for step in (6, 5):
        group.update_num_computed_tokens(step)
        assert group.is_prefill() is True
    group.update_num_computed_tokens(1)
    assert group.is_prefill() is False

    sequences = group.get_seqs()
    assert len(sequences) == 1

    # Append a decode token, then force a recompute: all 13 tokens
    # (12 prompt + 1 appended) become uncomputed again.
    sequences[0].data.append_token_id(1, logprob=0.0)
    for sequence in group.get_seqs():
        sequence.reset_state_for_recompute()
    assert group.is_prefill() is True
    for step in (5, 7):
        group.update_num_computed_tokens(step)
        assert group.is_prefill() is True
    group.update_num_computed_tokens(1)
    assert group.is_prefill() is False
|
||||
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)}
|
||||
)
|
||||
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
|
||||
{"1": torch.zeros([2, 4], dtype=torch.int32)}
|
||||
)
|
||||
assert (
|
||||
same_key_same_value_intermediate_tensors_1
|
||||
== same_key_same_value_intermediate_tensors_2
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user