[Bugfix] Fix precision issues caused by preemption of pooled requests (#6049)
### What this PR does / why we need it?
Fixes precision issues caused by preemption of pooled requests. When a request is preempted, its `RequestTracker` and related bookkeeping are now dropped, and when the request is resumed its tracker is rebuilt from the freshly allocated KV blocks. This way the connector no longer saves or loads KV data through block ids that were freed and reassigned to another request.
### Does this PR introduce _any_ user-facing change?
See #6045.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: d68209402d
---------
Signed-off-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
Co-authored-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
```python
# @@ -43,6 +43,7 @@ class KVPoolScheduler:
        self._block_size *= self.dcp_size
        # request_id -> full_token_ids
        self._request_trackers: dict[str, RequestTracker] = {}
        self._preempted_req_ids: set[str] = set()
        # Whether to discard partial chunks
        self._discard_partial_chunks = (
            vllm_config.kv_transfer_config.get_from_extra_config(
```
```python
# @@ -162,6 +163,11 @@ class KVPoolScheduler:
            self._unfinished_requests.pop(finished_req_id, None)
            self._unfinished_request_ids.discard(finished_req_id)

        # Record preempted requests and drop their stale bookkeeping.
        self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
        for req_id in scheduler_output.preempted_req_ids:
            self._request_trackers.pop(req_id, None)
            self._unfinished_requests.pop(req_id, None)

        meta = AscendConnectorMetadata(self._unfinished_request_ids,
                                       scheduler_output.preempted_req_ids)

        for request in scheduler_output.scheduled_new_reqs:
```
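Why dropping the tracker on preemption matters: once a request is preempted, its KV blocks are freed and can be reassigned, so the block ids recorded in its old tracker become dangling. A rough sketch of the failure mode this hunk guards against, using invented ids and block numbers:

```python
# Hypothetical timeline (invented ids), not actual scheduler output:
#   t0: "req-A" runs with blocks [3, 4]; its tracker records them.
#   t1: "req-A" is preempted; blocks [3, 4] are freed and given to "req-B".
#   t2: "req-A" resumes and is allocated fresh blocks [7, 8].
stale_tracker_blocks = [3, 4]   # what a surviving tracker would still hold
resumed_blocks = [7, 8]         # what the request actually owns after resume

# Saving or loading through the stale ids would touch req-B's KV cache and
# silently corrupt activations: the "precision issue" in the PR title.
assert stale_tracker_blocks != resumed_blocks
```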
```python
# @@ -170,15 +176,24 @@ class KVPoolScheduler:
            num_tokens_to_compute = (
                request.num_computed_tokens +
                scheduler_output.num_scheduled_tokens[request.req_id])
            # block_ids may be flat or nested per KV cache group; flatten it
            # before building the tracker.
            if not isinstance(request.block_ids[0], list):
                unfolded_block_ids = request.block_ids.copy()
            else:
                unfolded_block_ids = request.block_ids[0].copy()
            request_tracker = RequestTracker(
                req_id=request.req_id,
                token_len=num_tokens_to_compute,
                allocated_block_ids=unfolded_block_ids,
                num_saved_tokens=0,
            )
            self._request_trackers[request.req_id] = request_tracker
            last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                      self._block_size * self._block_size)
                                     if self._discard_partial_chunks else
                                     len(request.prompt_token_ids))
            request_tuple = self._unfinished_requests.get(request.req_id)
            request_real = request_tuple[0]  # type: ignore[index]

            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
```
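The `isinstance` check above exists because `request.block_ids` can arrive in two layouts: a flat list of block ids, or a list of per-KV-cache-group lists, of which only group 0 is used here. A standalone sketch of that normalization (hypothetical helper name, made-up data):

```python
def unfold_block_ids(block_ids):
    """Return a flat copy of block ids, whether the scheduler passed a
    flat list or a per-KV-cache-group list of lists (group 0 only)."""
    if block_ids and isinstance(block_ids[0], list):
        return block_ids[0].copy()
    return block_ids.copy()

print(unfold_block_ids([1, 2, 3]))    # [1, 2, 3]
print(unfold_block_ids([[1, 2, 3]]))  # [1, 2, 3]
```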
```python
# @@ -195,38 +210,78 @@ class KVPoolScheduler:
        cached_reqs = scheduler_output.scheduled_cached_reqs
        if not force_skip_save:
            for i, req_id in enumerate(cached_reqs.req_ids):
                new_block_ids = cached_reqs.new_block_ids[i]
                if not new_block_ids:
                    continue
                # resumed request
                if req_id in self._preempted_req_ids:
                    if isinstance(new_block_ids, tuple):
                        new_block_ids = new_block_ids[0].copy()
                    else:
                        new_block_ids = new_block_ids.copy()
                    self._preempted_req_ids.discard(req_id)
                    load_spec = self.load_specs.pop(req_id, None)
                    request_tuple = self._unfinished_requests.get(req_id)
                    request_real = request_tuple[0]  # type: ignore[index]
                    num_tokens_to_compute = (
                        request_real.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
                    # Rebuild the tracker from the freshly allocated blocks
                    # instead of reusing state recorded before preemption.
                    request_tracker = RequestTracker(
                        req_id=req_id,
                        token_len=num_tokens_to_compute,
                        allocated_block_ids=new_block_ids,
                        num_saved_tokens=0,
                    )
                    self._request_trackers[req_id] = request_tracker
                    last_chunk_tokens_num = (
                        (len(request_real.prompt_token_ids) //
                         self._block_size * self._block_size)
                        if self._discard_partial_chunks else
                        len(request_real.prompt_token_ids))
                    req_meta = ReqMeta.from_request_tracker(
                        request_tracker,
                        self._block_size,
                        load_spec=load_spec,
                        skip_save=force_skip_save,
                        block_hashes=request_real.block_hashes,
                        is_last_chunk=request_tracker.token_len
                        >= last_chunk_tokens_num,
                        discard_partial_chunks=self._discard_partial_chunks,
                    )
                # decode/chunked request
                else:
                    request_tracker = self._request_trackers[req_id]
                    num_new_tokens = scheduler_output.num_scheduled_tokens[
                        req_id]
                    req_tuple = self._unfinished_requests.get(req_id)
                    if req_tuple:
                        request = req_tuple[0]
                        num_current_tokens = request_tracker.token_len
                        new_token_ids = request.all_token_ids[
                            num_current_tokens:num_current_tokens +
                            num_new_tokens]
                        request_tracker.token_len += len(new_token_ids)
                    else:
                        raise ValueError(
                            f"Request {req_id} is not in _unfinished_requests, "
                            f"but it is scheduled to be cached")
                    num_computed_token = cached_reqs.num_computed_tokens[i]
                    if num_computed_token >= len(request.prompt_token_ids):
                        continue
                    request_tracker.update(new_block_ids)

                    last_chunk_tokens_num = (
                        (len(request.prompt_token_ids) //
                         self._block_size * self._block_size)
                        if self._discard_partial_chunks else
                        len(request.prompt_token_ids))
                    req_meta = ReqMeta.from_request_tracker(
                        request_tracker,
                        self._block_size,
                        load_spec=None,
                        skip_save=force_skip_save,
                        block_hashes=request.block_hashes,
                        is_last_chunk=request_tracker.token_len
                        >= last_chunk_tokens_num,
                        discard_partial_chunks=self._discard_partial_chunks,
                    )
                if req_meta is not None:
                    meta.add_request(req_meta)
```
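For decode/chunked requests, the tracker's token accounting advances by slicing exactly the newly scheduled tokens out of the request's full token sequence. A toy illustration with invented values:

```python
all_token_ids = list(range(10))  # full sequence so far (toy data)
token_len = 6                    # tokens the tracker has accounted for
num_new_tokens = 3               # tokens scheduled this step

new_token_ids = all_token_ids[token_len:token_len + num_new_tokens]  # [6, 7, 8]
token_len += len(new_token_ids)  # 9; advances only by what was actually sliced
```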
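`last_chunk_tokens_num` decides when `is_last_chunk` flips to true: with `discard_partial_chunks` enabled, only whole blocks are persisted to the pool, so the prompt length is rounded down to a block multiple. A quick worked example with assumed numbers:

```python
block_size = 128
prompt_len = 300

# Discarding partial chunks: the last-chunk boundary is the prompt length
# rounded down to a whole number of blocks.
last_chunk_tokens_num = prompt_len // block_size * block_size  # 256

# Keeping partial chunks, it would simply be prompt_len (300).

token_len = 256
is_last_chunk = token_len >= last_chunk_tokens_num  # True
```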