Fix flush cache API for spec v2 (#11918)
This commit is contained in:
@@ -27,3 +27,8 @@ class BaseSpecWorker(ABC):
|
||||
@abstractmethod
|
||||
def draft_worker(self) -> BaseDraftWorker:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear_cache_pool(self):
|
||||
# TODO: move this abstract method to BaseTpWorker and call through self.model_runner
|
||||
pass
|
||||
|
||||
@@ -613,8 +613,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
return parent_list, top_scores_index, draft_tokens
|
||||
|
||||
def clear_cache_pool(self):
|
||||
self.model_runner.req_to_token_pool.clear()
|
||||
self.model_runner.token_to_kv_pool_allocator.clear()
|
||||
# allocator and kv cache pool are shared with target worker
|
||||
pass
|
||||
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
|
||||
@@ -539,6 +539,10 @@ class EAGLEWorkerV2(BaseSpecWorker):
|
||||
def draft_worker(self):
|
||||
return self._draft_worker
|
||||
|
||||
def clear_cache_pool(self):
|
||||
# allocator and kv cache pool are shared with target worker, which are cleared in scheduler
|
||||
pass
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
if model_worker_batch.forward_mode.is_decode():
|
||||
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
||||
|
||||
Reference in New Issue
Block a user