forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
This commit is contained in:
@@ -7,6 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
@@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
def test_update_done_task_count(self):
    """Check the tracker's bookkeeping around ``update_done_task_count``.

    Three phases are asserted, in order:

    1. A fresh tracker has empty ``finished_requests``,
       ``delayed_free_requests`` and ``record_finished_requests``.
    2. ``add_delayed_request("req_1", t)`` stores ``t`` under ``"req_1"``
       in ``delayed_free_requests`` and leaves the record set empty.
    3. ``update_done_task_count`` on the delayed request moves it to
       ``finished_requests`` and clears the delayed entry; on a request
       that was never delayed ("req_2") it is additionally expected to
       show up in ``record_finished_requests``.
    """
    # Phase 1: nothing is tracked yet.
    self.assertEqual(len(self.tracker.finished_requests), 0)
    self.assertEqual(len(self.tracker.delayed_free_requests), 0)
    self.assertEqual(len(self.tracker.record_finished_requests), 0)

    # Phase 2: register a delayed-free request.
    current_time = time.time()
    self.tracker.add_delayed_request("req_1", current_time)
    result = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(len(result), 1)
    # delayed_free_requests is a mapping keyed by request id (the init
    # test in this file checks it is an OrderedDict), so look the entry
    # up by key; positional access such as result[0] would raise KeyError.
    self.assertEqual(result["req_1"], current_time)
    self.assertEqual(len(result_record), 0)

    # Phase 3a: completing the delayed request drains the delayed mapping.
    self.tracker.update_done_task_count("req_1")
    result_finished = self.tracker.finished_requests
    result_delayed = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(result_finished, {"req_1"})
    self.assertEqual(len(result_delayed), 0)
    self.assertEqual(len(result_record), 0)

    # Phase 3b: completing a request that was never delayed is recorded.
    self.tracker.update_done_task_count("req_2")
    result_finished = self.tracker.finished_requests
    result_delayed = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(result_finished, {"req_1", "req_2"})
    self.assertEqual(len(result_delayed), 0)
    self.assertEqual(len(result_record), 1)
    self.assertEqual(result_record, {"req_2"})
|
||||
|
||||
def test_updtate_add_delayed_request(self) -> None:
    """Completion reported before the delayed-free registration.

    The test first reports "req2" as done and expects one entry in
    ``record_finished_requests``. It then registers the same id as a
    delayed request and expects both ``delayed_free_requests`` and
    ``record_finished_requests`` to be empty afterwards.
    """
    tracker = self.tracker
    request_id = "req2"

    # Done notification arrives first: the id must be recorded.
    tracker.update_done_task_count(request_id)
    self.assertEqual(len(tracker.record_finished_requests), 1)

    # Late delayed-free registration for the same id.
    tracker.add_delayed_request(request_id, time.time())
    delayed_after = tracker.delayed_free_requests
    record_after = tracker.record_finished_requests
    self.assertEqual(len(delayed_after), 0)
    self.assertEqual(len(record_after), 0)
|
||||
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
@@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
self.assertIn("req_2", result_delay)
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1")
|
||||
@@ -961,6 +986,46 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
for p in self.patches:
|
||||
p.stop() # type: ignore
|
||||
|
||||
def test_worker_use_ascend_direct(self):
    """Build a MooncakeConnectorWorker with ``use_ascend_direct`` on and off.

    Each flag value runs in its own ``subTest``. The vLLM config is a
    ``MagicMock`` whose ``get_from_extra_config`` answers from a small
    dict (falling back to the caller-supplied default), and the
    rank / TP-group / IP helpers of the connector module are patched
    for the duration of the construction. The only assertion is that
    the constructed worker is not ``None``.
    """
    for use_ascend_direct in (True, False):
        with self.subTest(use_ascend_direct=use_ascend_direct):

            def extra_config_lookup(k, d):
                # A fresh dict per call, exactly like the lambda it
                # replaces; unknown keys yield the default ``d``.
                extra_config = {
                    "prefill": {
                        "tp_size": 2,
                        "dp_size": 1
                    },
                    "decode": {
                        "tp_size": 2,
                        "dp_size": 1
                    },
                    "use_ascend_direct": use_ascend_direct,
                }
                return extra_config.get(k, d)

            config = MagicMock()

            kv_transfer = MagicMock()
            config.kv_transfer_config = kv_transfer
            kv_transfer.get_from_extra_config.side_effect = extra_config_lookup
            kv_transfer.kv_port = 8000
            kv_transfer.kv_role = 'worker'

            parallel = MagicMock()
            config.parallel_config = parallel
            parallel.tensor_parallel_size = 2
            parallel.data_parallel_rank_local = 0
            parallel.data_parallel_size_local = 1

            # Patchers are inert until entered, so creating them up front
            # and entering them left-to-right matches the nested form.
            rank_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
                return_value=0)
            tp_group_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_tp_group",
                return_value=None)
            ip_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_ip",
                return_value="127.0.0.1")

            with rank_patch, tp_group_patch, ip_patch:
                worker = MooncakeConnectorWorker(config, self.engine_id)
                self.assertIsNotNone(worker)
|
||||
|
||||
def test_register_kv_caches_producer(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
@@ -19,8 +20,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
os.environ["VLLM_USE_V1"] = "1"
|
||||
|
||||
@@ -131,10 +130,10 @@ def create_request(
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(sha256)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
@@ -160,27 +159,14 @@ def create_request(
|
||||
else:
|
||||
prompt_token_ids = [i * request_id for i in range(num_tokens)]
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
else:
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
req.kv_transfer_params = kv_transfer_params
|
||||
return req
|
||||
|
||||
@@ -208,26 +194,15 @@ def create_model_runner_output(
|
||||
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
|
||||
finished_recving=finished_recving)
|
||||
extra_args = {"kv_connector_output": kv_connector_output}
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
return model_runner_output
|
||||
|
||||
Reference in New Issue
Block a user