Upstreaming hicache bug fixes (#7267)
This commit is contained in:
@@ -239,7 +239,7 @@ class WorkloadGenerator:
|
|||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
)
|
)
|
||||||
self.candidate_inputs = [i[0] for i in self.candidate_inputs]
|
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
|
||||||
|
|
||||||
init_requests = [
|
init_requests = [
|
||||||
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
||||||
|
|||||||
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class LayerDoneCounter:
|
class LayerDoneCounter:
|
||||||
def __init__(self, num_layers):
|
def __init__(self, num_layers):
|
||||||
self.counter = num_layers
|
self.num_layers = num_layers
|
||||||
self.condition = threading.Condition()
|
# extra producer and consumer counters for overlap mode
|
||||||
|
self.num_counters = 3
|
||||||
|
self.counters = [num_layers] * self.num_counters
|
||||||
|
self.conditions = [threading.Condition() for _ in range(self.num_counters)]
|
||||||
|
self.producer_index = 0
|
||||||
|
self.consumer_index = 0
|
||||||
|
|
||||||
|
def next_producer(self):
|
||||||
|
return (self.producer_index + 1) % self.num_counters
|
||||||
|
|
||||||
|
def update_producer(self):
|
||||||
|
self.producer_index = self.next_producer()
|
||||||
|
return self.producer_index
|
||||||
|
|
||||||
|
def set_consumer(self, index):
|
||||||
|
self.consumer_index = index
|
||||||
|
|
||||||
def increment(self):
|
def increment(self):
|
||||||
with self.condition:
|
with self.conditions[self.producer_index]:
|
||||||
self.counter += 1
|
self.counters[self.producer_index] += 1
|
||||||
self.condition.notify_all()
|
self.conditions[self.producer_index].notify_all()
|
||||||
|
|
||||||
def wait_until(self, threshold):
|
def wait_until(self, threshold):
|
||||||
with self.condition:
|
with self.conditions[self.consumer_index]:
|
||||||
while self.counter <= threshold:
|
while self.counters[self.consumer_index] <= threshold:
|
||||||
self.condition.wait()
|
self.conditions[self.consumer_index].wait()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
with self.condition:
|
with self.conditions[self.producer_index]:
|
||||||
self.counter = 0
|
self.counters[self.producer_index] = 0
|
||||||
|
|
||||||
|
|
||||||
class CacheOperation:
|
class CacheOperation:
|
||||||
@@ -296,7 +311,6 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.load_queue.get(block=True, timeout=1)
|
operation = self.load_queue.get(block=True, timeout=1)
|
||||||
# time.sleep(18e-6 * len(operation.host_indices))
|
|
||||||
operation.data = self.mem_pool_host.get_flat_data(
|
operation.data = self.mem_pool_host.get_flat_data(
|
||||||
operation.host_indices
|
operation.host_indices
|
||||||
)
|
)
|
||||||
@@ -320,6 +334,7 @@ class HiCacheController:
|
|||||||
if not self.load_cache_event.is_set():
|
if not self.load_cache_event.is_set():
|
||||||
continue
|
continue
|
||||||
self.load_cache_event.clear()
|
self.load_cache_event.clear()
|
||||||
|
self.layer_done_counter.update_producer()
|
||||||
|
|
||||||
batch_operation = None
|
batch_operation = None
|
||||||
while self.load_queue.qsize() > 0:
|
while self.load_queue.qsize() > 0:
|
||||||
@@ -331,6 +346,7 @@ class HiCacheController:
|
|||||||
if batch_operation is None:
|
if batch_operation is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# start layer-wise KV cache transfer from CPU to GPU
|
||||||
self.layer_done_counter.reset()
|
self.layer_done_counter.reset()
|
||||||
for i in range(self.mem_pool_host.layer_num):
|
for i in range(self.mem_pool_host.layer_num):
|
||||||
if self.page_size == 1:
|
if self.page_size == 1:
|
||||||
@@ -466,6 +482,7 @@ class HiCacheController:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
|
||||||
|
# todo (zhiqiang): double buffering to be deprecated
|
||||||
def write_thread_func_buffer(self):
|
def write_thread_func_buffer(self):
|
||||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
||||||
aux_thread.start()
|
aux_thread.start()
|
||||||
|
|||||||
@@ -659,14 +659,6 @@ class Req:
|
|||||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||||
)
|
)
|
||||||
elif enable_hierarchical_cache:
|
|
||||||
# in case last_node is evicted during scheduling, we need to update the prefix_indices
|
|
||||||
while self.last_node.evicted:
|
|
||||||
self.prefix_indices = self.prefix_indices[
|
|
||||||
: -len(self.last_node.host_value)
|
|
||||||
]
|
|
||||||
self.last_node = self.last_node.parent
|
|
||||||
|
|
||||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||||
|
|
||||||
def adjust_max_prefix_ids(self):
|
def adjust_max_prefix_ids(self):
|
||||||
@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
|
# hicache pointer for synchronizing data loading from CPU to GPU
|
||||||
|
hicache_consumer_index: int = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
token_type_ids=self.token_type_ids,
|
token_type_ids=self.token_type_ids,
|
||||||
spec_algorithm=self.spec_algorithm,
|
spec_algorithm=self.spec_algorithm,
|
||||||
spec_info=self.spec_info,
|
spec_info=self.spec_info,
|
||||||
|
hicache_consumer_index=self.hicache_consumer_index,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
CaptureHiddenMode.FULL
|
CaptureHiddenMode.FULL
|
||||||
if self.return_hidden_states
|
if self.return_hidden_states
|
||||||
@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
|
|||||||
# If set, the output of the batch contains the hidden states of the run.
|
# If set, the output of the batch contains the hidden states of the run.
|
||||||
capture_hidden_mode: CaptureHiddenMode = None
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
spec_num_draft_tokens: Optional[int] = None
|
spec_num_draft_tokens: Optional[int] = None
|
||||||
|
hicache_consumer_index: int = 0
|
||||||
|
|
||||||
# Overlap event
|
# Overlap event
|
||||||
launch_done: Optional[threading.Event] = None
|
launch_done: Optional[threading.Event] = None
|
||||||
|
|||||||
@@ -565,6 +565,10 @@ class Scheduler(
|
|||||||
hicache_size=server_args.hicache_size,
|
hicache_size=server_args.hicache_size,
|
||||||
hicache_write_policy=server_args.hicache_write_policy,
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
)
|
)
|
||||||
|
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||||
|
self.tree_cache.cache_controller.layer_done_counter
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.tree_cache = RadixCache(
|
self.tree_cache = RadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
@@ -1514,8 +1518,13 @@ class Scheduler(
|
|||||||
self.running_batch.batch_is_full = True
|
self.running_batch.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# bypass prefix_computed if enable_hierarchical_cache
|
||||||
req.init_next_round_input(
|
req.init_next_round_input(
|
||||||
None if prefix_computed else self.tree_cache,
|
(
|
||||||
|
None
|
||||||
|
if (prefix_computed and not self.enable_hierarchical_cache)
|
||||||
|
else self.tree_cache
|
||||||
|
),
|
||||||
self.enable_hierarchical_cache,
|
self.enable_hierarchical_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1548,9 +1557,6 @@ class Scheduler(
|
|||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.enable_hierarchical_cache:
|
|
||||||
self.tree_cache.ready_to_load_cache()
|
|
||||||
|
|
||||||
if adder.new_chunked_req is not None:
|
if adder.new_chunked_req is not None:
|
||||||
assert self.chunked_req is None
|
assert self.chunked_req is None
|
||||||
self.chunked_req = adder.new_chunked_req
|
self.chunked_req = adder.new_chunked_req
|
||||||
@@ -1574,6 +1580,10 @@ class Scheduler(
|
|||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
chunked_req=self.chunked_req,
|
chunked_req=self.chunked_req,
|
||||||
)
|
)
|
||||||
|
if self.enable_hierarchical_cache:
|
||||||
|
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
|
||||||
|
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
|
||||||
|
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
@@ -1649,6 +1659,11 @@ class Scheduler(
|
|||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
|
||||||
|
# update the consumer index of hicache to the running batch
|
||||||
|
self.tp_worker.set_hicache_consumer(
|
||||||
|
model_worker_batch.hicache_consumer_index
|
||||||
|
)
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
|
|||||||
@@ -147,6 +147,15 @@ class TpModelWorker:
|
|||||||
# A reference make this class has the same member as TpModelWorkerClient
|
# A reference make this class has the same member as TpModelWorkerClient
|
||||||
self.worker = self
|
self.worker = self
|
||||||
|
|
||||||
|
self.hicache_layer_transfer_counter = None
|
||||||
|
|
||||||
|
def register_hicache_layer_transfer_counter(self, counter):
|
||||||
|
self.hicache_layer_transfer_counter = counter
|
||||||
|
|
||||||
|
def set_hicache_consumer(self, consumer_index):
|
||||||
|
if self.hicache_layer_transfer_counter is not None:
|
||||||
|
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
||||||
|
|
||||||
def get_worker_info(self):
|
def get_worker_info(self):
|
||||||
return (
|
return (
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
|
|||||||
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
|
|||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
|
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
|
||||||
|
|
||||||
|
self.hicache_layer_transfer_counter = None
|
||||||
|
|
||||||
|
def register_hicache_layer_transfer_counter(self, counter):
|
||||||
|
self.hicache_layer_transfer_counter = counter
|
||||||
|
|
||||||
|
def set_hicache_consumer(self, consumer_index):
|
||||||
|
if self.hicache_layer_transfer_counter is not None:
|
||||||
|
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
||||||
|
|
||||||
def get_worker_info(self):
|
def get_worker_info(self):
|
||||||
return self.worker.get_worker_info()
|
return self.worker.get_worker_info()
|
||||||
|
|
||||||
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
|
|||||||
input_ids = model_worker_batch.input_ids
|
input_ids = model_worker_batch.input_ids
|
||||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||||
|
|
||||||
|
# update the consumer index of hicache to the running batch
|
||||||
|
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
self.worker.forward_batch_generation(
|
self.worker.forward_batch_generation(
|
||||||
|
|||||||
@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
|
|||||||
return last_node, prefix_indices
|
return last_node, prefix_indices
|
||||||
|
|
||||||
def ready_to_load_cache(self):
|
def ready_to_load_cache(self):
|
||||||
|
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
||||||
self.load_cache_event.set()
|
self.load_cache_event.set()
|
||||||
|
return producer_index
|
||||||
|
|
||||||
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
||||||
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
||||||
@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
|
|||||||
new_node.lock_ref = child.lock_ref
|
new_node.lock_ref = child.lock_ref
|
||||||
new_node.key = child.key[:split_len]
|
new_node.key = child.key[:split_len]
|
||||||
new_node.loading = child.loading
|
new_node.loading = child.loading
|
||||||
|
new_node.hit_count = child.hit_count
|
||||||
|
|
||||||
# split value and host value if exists
|
# split value and host value if exists
|
||||||
if child.evicted:
|
if child.evicted:
|
||||||
|
|||||||
Reference in New Issue
Block a user