[Bugfix] Fixed precision issues caused by pooled request pooling (#6049)
### What this PR does / why we need it?
Fixed precision issues caused by pooled request pooling
### Does this PR introduce _any_ user-facing change?
See PR #6045.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: d68209402d
---------
Signed-off-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
Co-authored-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
This commit is contained in:
@@ -43,6 +43,7 @@ class KVPoolScheduler:
|
|||||||
self._block_size *= self.dcp_size
|
self._block_size *= self.dcp_size
|
||||||
# request_id -> full_token_ids
|
# request_id -> full_token_ids
|
||||||
self._request_trackers: dict[str, RequestTracker] = {}
|
self._request_trackers: dict[str, RequestTracker] = {}
|
||||||
|
self._preempted_req_ids: set[str] = set()
|
||||||
# Whether to discard partial chunks
|
# Whether to discard partial chunks
|
||||||
self._discard_partial_chunks = (
|
self._discard_partial_chunks = (
|
||||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||||
@@ -161,6 +162,11 @@ class KVPoolScheduler:
|
|||||||
self._request_trackers.pop(finished_req_id, None)
|
self._request_trackers.pop(finished_req_id, None)
|
||||||
self._unfinished_requests.pop(finished_req_id, None)
|
self._unfinished_requests.pop(finished_req_id, None)
|
||||||
self._unfinished_request_ids.discard(finished_req_id)
|
self._unfinished_request_ids.discard(finished_req_id)
|
||||||
|
|
||||||
|
for req_id in scheduler_output.preempted_req_ids:
|
||||||
|
self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
|
||||||
|
self._request_trackers.pop(req_id, None)
|
||||||
|
self._unfinished_requests.pop(req_id, None)
|
||||||
|
|
||||||
meta = AscendConnectorMetadata(self._unfinished_request_ids, scheduler_output.preempted_req_ids)
|
meta = AscendConnectorMetadata(self._unfinished_request_ids, scheduler_output.preempted_req_ids)
|
||||||
|
|
||||||
@@ -170,15 +176,24 @@ class KVPoolScheduler:
|
|||||||
num_tokens_to_compute = (
|
num_tokens_to_compute = (
|
||||||
request.num_computed_tokens +
|
request.num_computed_tokens +
|
||||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
scheduler_output.num_scheduled_tokens[request.req_id])
|
||||||
request_tracker = RequestTracker.from_new_request(
|
request_tuple = self._unfinished_requests.get(request.req_id)
|
||||||
request, num_tokens_to_compute)
|
request_real = request_tuple[0] # type: ignore[index]
|
||||||
|
if not isinstance(request.block_ids[0], list):
|
||||||
|
unfolded_block_ids = request.block_ids.copy()
|
||||||
|
else:
|
||||||
|
unfolded_block_ids = request.block_ids[0].copy()
|
||||||
|
request_tracker = RequestTracker(
|
||||||
|
req_id=request.req_id,
|
||||||
|
token_len=num_tokens_to_compute,
|
||||||
|
allocated_block_ids=unfolded_block_ids,
|
||||||
|
num_saved_tokens=0,
|
||||||
|
)
|
||||||
self._request_trackers[request.req_id] = request_tracker
|
self._request_trackers[request.req_id] = request_tracker
|
||||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||||
self._block_size * self._block_size)
|
self._block_size * self._block_size)
|
||||||
if self._discard_partial_chunks else len(
|
if self._discard_partial_chunks else len(
|
||||||
request.prompt_token_ids))
|
request.prompt_token_ids))
|
||||||
request_tuple = self._unfinished_requests.get(request.req_id)
|
|
||||||
request_real = request_tuple[0] # type: ignore[index]
|
|
||||||
req_meta = ReqMeta.from_request_tracker(
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
request_tracker,
|
request_tracker,
|
||||||
self._block_size,
|
self._block_size,
|
||||||
@@ -195,38 +210,78 @@ class KVPoolScheduler:
|
|||||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
if not force_skip_save:
|
if not force_skip_save:
|
||||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||||
request_tracker = self._request_trackers[req_id]
|
# resumed request
|
||||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
|
||||||
req_tuple = self._unfinished_requests.get(req_id)
|
|
||||||
if req_tuple:
|
|
||||||
request = req_tuple[0]
|
|
||||||
num_current_tokens = request_tracker.token_len
|
|
||||||
new_token_ids = request.all_token_ids[
|
|
||||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
|
||||||
request_tracker.token_len += len(new_token_ids)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Request {req_id} is not in _unfinished_requests, "
|
|
||||||
f"but it is scheduled to be cached")
|
|
||||||
new_block_ids = cached_reqs.new_block_ids[i]
|
new_block_ids = cached_reqs.new_block_ids[i]
|
||||||
if not new_block_ids:
|
if not new_block_ids:
|
||||||
continue
|
continue
|
||||||
request_tracker.update(new_block_ids)
|
if req_id in self._preempted_req_ids:
|
||||||
|
if isinstance(new_block_ids, tuple):
|
||||||
|
new_block_ids = new_block_ids[0].copy()
|
||||||
|
else:
|
||||||
|
new_block_ids = new_block_ids.copy()
|
||||||
|
self._preempted_req_ids.discard(req_id)
|
||||||
|
load_spec = self.load_specs.pop(req_id, None)
|
||||||
|
request_tuple = self._unfinished_requests.get(req_id)
|
||||||
|
request_real = request_tuple[0] # type: ignore[index]
|
||||||
|
num_tokens_to_compute = (
|
||||||
|
request_real.num_computed_tokens +
|
||||||
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
|
request_tracker = RequestTracker(
|
||||||
|
req_id=req_id,
|
||||||
|
token_len=num_tokens_to_compute,
|
||||||
|
allocated_block_ids=new_block_ids,
|
||||||
|
num_saved_tokens=0,
|
||||||
|
)
|
||||||
|
self._request_trackers[req_id] = request_tracker
|
||||||
|
last_chunk_tokens_num = ((len(request_real.prompt_token_ids) //
|
||||||
|
self._block_size * self._block_size)
|
||||||
|
if self._discard_partial_chunks else len(
|
||||||
|
request_real.prompt_token_ids))
|
||||||
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
|
request_tracker,
|
||||||
|
self._block_size,
|
||||||
|
load_spec=load_spec,
|
||||||
|
skip_save=force_skip_save,
|
||||||
|
block_hashes=request_real.block_hashes,
|
||||||
|
is_last_chunk=request_tracker.token_len
|
||||||
|
>= last_chunk_tokens_num,
|
||||||
|
discard_partial_chunks=self._discard_partial_chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# decode/chunked request
|
||||||
|
else:
|
||||||
|
request_tracker = self._request_trackers[req_id]
|
||||||
|
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||||
|
req_tuple = self._unfinished_requests.get(req_id)
|
||||||
|
if req_tuple:
|
||||||
|
request = req_tuple[0]
|
||||||
|
num_current_tokens = request_tracker.token_len
|
||||||
|
new_token_ids = request.all_token_ids[
|
||||||
|
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||||
|
request_tracker.token_len += len(new_token_ids)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Request {req_id} is not in _unfinished_requests, "
|
||||||
|
f"but it is scheduled to be cached")
|
||||||
|
num_computed_token = cached_reqs.num_computed_tokens[i]
|
||||||
|
if num_computed_token >= len(request.prompt_token_ids):
|
||||||
|
continue
|
||||||
|
request_tracker.update(new_block_ids)
|
||||||
|
|
||||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||||
self._block_size * self._block_size)
|
self._block_size * self._block_size)
|
||||||
if self._discard_partial_chunks else
|
if self._discard_partial_chunks else
|
||||||
len(request.prompt_token_ids))
|
len(request.prompt_token_ids))
|
||||||
req_meta = ReqMeta.from_request_tracker(
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
request_tracker,
|
request_tracker,
|
||||||
self._block_size,
|
self._block_size,
|
||||||
load_spec=None,
|
load_spec=None,
|
||||||
skip_save=force_skip_save,
|
skip_save=force_skip_save,
|
||||||
block_hashes=request.block_hashes,
|
block_hashes=request.block_hashes,
|
||||||
is_last_chunk=request_tracker.token_len
|
is_last_chunk=request_tracker.token_len
|
||||||
>= last_chunk_tokens_num,
|
>= last_chunk_tokens_num,
|
||||||
discard_partial_chunks=self._discard_partial_chunks,
|
discard_partial_chunks=self._discard_partial_chunks,
|
||||||
)
|
)
|
||||||
if req_meta is not None:
|
if req_meta is not None:
|
||||||
meta.add_request(req_meta)
|
meta.add_request(req_meta)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user