Auto adjust new ratio (#708)
This commit is contained in:
@@ -16,9 +16,9 @@ class GlobalConfig:
|
||||
self.wait_for_new_request_delay = 0.0006
|
||||
|
||||
# Runtime constants: New generation token ratio estimation
|
||||
self.base_new_token_ratio = 0.4
|
||||
self.init_new_token_ratio = 0.7
|
||||
self.base_min_new_token_ratio = 0.2
|
||||
self.new_token_ratio_decay = 0.0001
|
||||
self.new_token_ratio_decay = 0.001
|
||||
self.new_token_ratio_recovery = 0.05
|
||||
|
||||
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
||||
@@ -27,6 +27,7 @@ class GlobalConfig:
|
||||
|
||||
# Runtime constants: others
|
||||
self.num_continue_decode_steps = 10
|
||||
self.retract_decode_steps = 20
|
||||
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
||||
|
||||
# Output tokenization configs
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||
@@ -431,7 +432,8 @@ class Batch:
|
||||
|
||||
def retract_decode(self):
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
# TODO(lsyin): improve the priority of retraction
|
||||
|
||||
# TODO(lsyin): improve retraction policy for radix cache
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
@@ -443,7 +445,17 @@ class Batch:
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
< len(sorted_indices) * global_config.retract_decode_steps
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
assert (
|
||||
self.token_to_kv_pool.available_size() > 0
|
||||
), "No space left for only one request"
|
||||
break
|
||||
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
@@ -468,7 +480,16 @@ class Batch:
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
|
||||
return retracted_reqs
|
||||
# After filter_batch, self.reqs holds only the requests that were kept (not retracted)
|
||||
total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
|
||||
total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
|
||||
|
||||
new_estimate_ratio = (
|
||||
total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
|
||||
) / total_max_new_tokens
|
||||
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||
|
||||
return retracted_reqs, new_estimate_ratio
|
||||
|
||||
def check_for_jump_forward(self, model_runner):
|
||||
jump_forward_reqs = []
|
||||
|
||||
@@ -228,6 +228,7 @@ class ModelTpServer:
|
||||
break
|
||||
else:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
|
||||
def print_stats(self):
|
||||
num_used = self.max_total_num_tokens - (
|
||||
@@ -536,9 +537,10 @@ class ModelTpServer:
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem():
|
||||
old_ratio = self.new_token_ratio
|
||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
||||
|
||||
retracted_reqs = batch.retract_decode()
|
||||
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||
self.new_token_ratio = new_token_ratio
|
||||
|
||||
logger.info(
|
||||
"decode out of memory happened, "
|
||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||
|
||||
Reference in New Issue
Block a user