[PD] Support decode retract and update decode.py (#7196)
@@ -31,7 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import (
+    KVCache,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
         gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
 
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
         self.num_reserved_decode_tokens = int(
             os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
         )
 
+        self.transfer_backend = transfer_backend
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
-        self.transfer_backend = transfer_backend
-        self.prefill_pp_size = prefill_pp_size
+        self.retracted_queue: List[Req] = []
         self.kv_manager = self._init_kv_manager()
 
     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
-        kv_args.engine_rank = self.tp_rank
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
         if self.draft_token_to_kv_pool is not None:
             # We should also transfer draft model kv cache. The indices are
             # always shared with a target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
 
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
+        kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
-            # Fake transfer for warmup reqs
-            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
-        else:
-            kv_receiver_class = get_kv_class(
-                self.transfer_backend, KVClassType.RECEIVER
-            )
-        kv_receiver = kv_receiver_class(
-            mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-            bootstrap_room=req.bootstrap_room,
-            data_parallel_rank=req.data_parallel_rank,
-        )
-        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
 
-    def extend(self, reqs: List[Req]) -> None:
+        if is_retracted:
+            self.retracted_queue.append(req)
+        else:
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                kv_receiver_class = get_kv_class(
+                    TransferBackend.FAKE, KVClassType.RECEIVER
+                )
+            else:
+                kv_receiver_class = get_kv_class(
+                    self.transfer_backend, KVClassType.RECEIVER
+                )
+
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+            )
+
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+            )
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
+
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+
+        # allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+
+            # load from cpu, release the cpu copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs
 
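The resume path above readmits a retracted request only when its full history plus fresh decode headroom fits the current budget. A minimal standalone restatement of that check (illustrative names, not the module's API):

    # A retracted request resumes only if its prefill tokens, the tokens it
    # already decoded, and the per-request decode headroom all fit the
    # remaining allocatable budget.
    def can_resume(input_len: int, output_len: int,
                   num_reserved_decode_tokens: int, allocatable_tokens: int) -> bool:
        required = input_len + output_len + num_reserved_decode_tokens
        return required <= allocatable_tokens
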
     def _update_handshake_waiters(self) -> None:
         if not self.queue:
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
 
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
+
+        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
+        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )
         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)
 
+        # Then, preallocate the remaining requests if possible
         for i, decode_req in enumerate(self.queue):
             if i in indices_to_remove:
                 continue
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break
 
             # Memory estimation: don't add if the projected memory cannot be met
             # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+                origin_input_len + self.num_reserved_decode_tokens
             )
 
-            if required_tokens_for_request > allocatable_tokens:
-                break
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
+                break
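
A toy calculation of the admission guard above, with assumed numbers (not taken from the PR): a request is admitted only if both its immediate need (input plus reserved decode headroom) and its worst case (input plus max_new_tokens, minus what retracting every running request would free) fit the allocatable budget.

    origin_input_len = 1024
    max_new_tokens = 4096
    num_reserved_decode_tokens = 512
    retractable_tokens = 3000   # total inflight tokens of the running requests
    allocatable_tokens = 2200

    required = origin_input_len + num_reserved_decode_tokens             # 1536
    worst_case = origin_input_len + max_new_tokens - retractable_tokens  # 2120
    admit = max(required, worst_case) <= allocatable_tokens              # True
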
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
 
         return preallocated_reqs
 
-    def _allocatable_tokens(self) -> int:
-        allocatable_tokens = (
-            self.token_to_kv_pool_allocator.available_size()
-            - self.num_reserved_decode_tokens
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
+            # preserve some space for future decode
+            self.num_reserved_decode_tokens
             * (
                 len(self.scheduler.running_batch.reqs)
                 + len(self.transfer_queue.queue)
                 + len(self.scheduler.waiting_queue)
-            )
-        )
+            ),
+            # make sure each request can finish if reach max_tokens with all other requests retracted
+            need_space_for_single_req,
+        )
 
         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
             self.scheduler.last_batch.reqs
         )
 
+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
         return allocatable_tokens
 
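A free-standing sketch of the budget rule in _allocatable_tokens (hypothetical helper with the same arithmetic): subtract from the free KV-pool size the larger of the headroom reserved for every inflight or queued request and the space one request would need to reach max_new_tokens after all other requests are retracted.

    def allocatable_tokens(available_size: int,
                           num_reserved_decode_tokens: int,
                           num_inflight_reqs: int,
                           need_space_for_single_req: int) -> int:
        # reserve decode headroom for every request that may emit tokens soon
        reserved = num_reserved_decode_tokens * num_inflight_reqs
        # also keep enough room for the worst-case single request
        return available_size - max(reserved, need_space_for_single_req)
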
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
 
-        assert req_pool_indices is not None
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
 
         req.req_pool_idx = req_pool_indices[0]
 
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
             ),
             extend_num_tokens=num_tokens,
         )
-        assert kv_loc is not None
+
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
 
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
@@ -402,6 +541,7 @@ class DecodeTransferQueue:
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)
 
-    def pop_transferred(self) -> List[DecodeRequest]:
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []
 
         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # if there are still retracted requests, we do not allocate new requests
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
 
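The scheduling policy introduced here is retract-first: resumed requests always precede new preallocation, and no new request is admitted while any retracted request is still waiting. A condensed sketch of that ordering (queue objects as in the diff, bodies simplified):

    def process_decode_queue_step(prealloc_queue, transfer_queue, waiting_queue):
        # retracted requests get first claim on freed KV space
        waiting_queue.extend(prealloc_queue.resume_retracted_reqs())
        if prealloc_queue.retracted_queue:
            return  # backlog remains: hold off on admitting new requests
        transfer_queue.extend(prealloc_queue.pop_preallocated())
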
@@ -25,6 +25,7 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
+import numpy as np
 import torch
 
 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
+            .astype(np.int64)
         )
         req.start_send_idx = end_idx
         if last_chunk:
 
@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
 
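On a decode server the retract path offloads a request's KV pages to host memory instead of discarding them; resume_retracted_reqs() later restores them via load_kv_cache(). A hypothetical holder object showing the intended pairing (get_cpu_copy/load_cpu_copy are the allocator methods added further below; the free() call is an assumption about the allocator, not quoted from the PR):

    class CpuKVHolder:
        def offload(self, allocator, indices):
            # chunked device-to-host copy, then release the GPU pages
            self.kv_cache_cpu = allocator.get_cpu_copy(indices)
            allocator.free(indices)

        def load(self, allocator, indices):
            # chunked host-to-device copy into freshly allocated pages
            allocator.load_cpu_copy(self.kv_cache_cpu, indices)
            self.kv_cache_cpu = None  # release the CPU copy
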
@@ -628,6 +628,7 @@ class Scheduler(
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
@@ -650,7 +651,11 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
                 gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                 transfer_backend=self.transfer_backend,
             )
 
@@ -1124,14 +1129,14 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
                 reqs, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)
 
@@ -1274,6 +1279,7 @@ class Scheduler(
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
@@ -1575,7 +1581,7 @@ class Scheduler(
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self._extend_requests_to_queue(retracted_reqs)
+            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
 
@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
         self.is_not_in_free_group = True
         self.free_group = []
 
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 class MHATokenToKVPool(KVCache):
 
@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self._create_buffers()
 
+        # used for chunked cpu-offloading
+        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
+                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
 
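The same chunked-copy pattern, reduced to a runnable sketch on plain tensors (buffer shape, chunk size, and indices are made up; assumes a CUDA device):

    import torch

    chunk_size = 8192
    k_buffer = torch.randn(65536, 8, 128, device="cuda", dtype=torch.bfloat16)
    indices = torch.arange(0, 10000, device="cuda")

    torch.cuda.synchronize()
    cpu_chunks = []
    for i in range(0, len(indices), chunk_size):
        chunk = indices[i : i + chunk_size]
        # gather the rows for this chunk and copy them device-to-host;
        # the trailing synchronize makes the copies safe to read
        cpu_chunks.append(k_buffer[chunk].to("cpu", non_blocking=True))
    torch.cuda.synchronize()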