Clean up allocators (#9134)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -113,6 +113,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_multimodal",
    "enable_symm_mem",
    "quantization",
    "enable_custom_logit_processor",
]

# Put some global args for easy access
@@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False
@@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)
@@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )
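Taken together, the hunks above drop the per-batch enable_custom_logit_processor flag from ScheduleBatch and from the init_new() signature; adding the key to GLOBAL_SERVER_ARGS_KEYS exposes it through the shared server-args dictionary instead. A minimal sketch of the assumed read path, under the assumption that consumers now look the flag up in global_server_args_dict (the helper below is illustrative, not part of this commit):

    from sglang.srt.managers.schedule_batch import global_server_args_dict

    # Hypothetical helper: read the flag from the shared server-args dict instead of
    # a ScheduleBatch field threaded through init_new().
    def custom_logit_processors_enabled() -> bool:
        return bool(global_server_args_dict["enable_custom_logit_processor"])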
@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        # Over estimate the number of tokens: assume each request needs a new page.
        num_tokens = (
            extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        # Over estimate the number of tokens: assume each request needs a new page.
        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size

        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
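Both allocation paths above deliberately over-estimate: every request is assumed to possibly open a new KV-cache page, and the extend path adds the new input tokens on top of that. A standalone sketch of the same arithmetic (the helper name is illustrative, not from this commit):

    # Illustrative sketch of the page-based over-estimate used above.
    def estimate_needed_tokens(num_reqs: int, page_size: int, extend_num_tokens: int = 0) -> int:
        # Assume every request may start a fresh page, then add any new extend tokens.
        return extend_num_tokens + num_reqs * page_size

For example, with page_size=16 and 4 decode requests, the estimate is 64 tokens even though decode only appends 4; the slack is what the tree-cache eviction check below guards against.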
@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        extend_prefix_lens = self.prefix_lens
        extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.forward_mode.is_decode_or_idle():
            attention_backend_str = global_server_args_dict["decode_attention_backend"]
        else:
            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
        # Create seq_lens_cpu when needed
        if (
            attention_backend_str
            in [
                "fa3",
                "flashinfer",
                "flashmla",
                "cutlass_mla",
                "ascend",
                "trtllm_mha",
                "aiter",
            ]
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
                seq_lens_cpu_cache
                if seq_lens_cpu_cache is not None
                else self.seq_lens.cpu()
            )
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache
            if seq_lens_cpu_cache is not None
            else self.seq_lens.cpu()
        )

        global bid
        bid += 1
        return ModelWorkerBatch(
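The hunk above also removes the backend-specific gating of seq_lens_cpu: previously the CPU copy was only materialized for a fixed list of attention backends (or when two-batch overlap was enabled), whereas now it is always materialized, reusing seq_lens_cpu_cache when the caller supplies one. A standalone sketch of the new, unconditional pattern (the function is illustrative, not from this commit):

    from typing import Optional
    import torch

    def get_seq_lens_cpu(seq_lens: torch.Tensor, seq_lens_cpu_cache: Optional[torch.Tensor]) -> torch.Tensor:
        # Reuse the caller-provided CPU copy when present; otherwise copy the device tensor once.
        return seq_lens_cpu_cache if seq_lens_cpu_cache is not None else seq_lens.cpu()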
@@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
        )

    def _evict_tree_cache_if_needed(
        self,
        num_tokens: int,
    ) -> None:
        if isinstance(self.tree_cache, SWAChunkCache):
    def _evict_tree_cache_if_needed(self, num_tokens: int):
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
            return

        if self.is_hybrid:
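The consolidated helper now returns early for a plain ChunkCache as well, not only SWAChunkCache. What happens after the visible lines (the is_hybrid branch and the actual eviction) is not shown in this diff, so the following is only a hedged sketch of the general shape; available_size() and evict() are assumed allocator/tree-cache methods and the import path is an assumption, not taken verbatim from this commit:

    # Hedged sketch; the non-hybrid eviction logic below is an assumption.
    from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache

    def evict_tree_cache_if_needed(tree_cache, token_to_kv_pool_allocator, num_tokens: int) -> None:
        # Chunk caches keep no reusable prefixes, so there is nothing to evict.
        if isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
            return
        # Assumed shape: free cached prefixes until the allocator can fit num_tokens.
        if token_to_kv_pool_allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)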
@@ -1634,7 +1634,6 @@ class Scheduler(
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
@@ -2031,7 +2030,6 @@ class Scheduler(
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch