Auto adjust new ratio (#708)
This commit is contained in:
@@ -16,9 +16,9 @@ class GlobalConfig:
|
|||||||
self.wait_for_new_request_delay = 0.0006
|
self.wait_for_new_request_delay = 0.0006
|
||||||
|
|
||||||
# Runtime constants: New generation token ratio estimation
|
# Runtime constants: New generation token ratio estimation
|
||||||
self.base_new_token_ratio = 0.4
|
self.init_new_token_ratio = 0.7
|
||||||
self.base_min_new_token_ratio = 0.2
|
self.base_min_new_token_ratio = 0.2
|
||||||
self.new_token_ratio_decay = 0.0001
|
self.new_token_ratio_decay = 0.001
|
||||||
self.new_token_ratio_recovery = 0.05
|
self.new_token_ratio_recovery = 0.05
|
||||||
|
|
||||||
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
||||||
@@ -27,6 +27,7 @@ class GlobalConfig:
|
|||||||
|
|
||||||
# Runtime constants: others
|
# Runtime constants: others
|
||||||
self.num_continue_decode_steps = 10
|
self.num_continue_decode_steps = 10
|
||||||
|
self.retract_decode_steps = 20
|
||||||
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
||||||
|
|
||||||
# Output tokenization configs
|
# Output tokenization configs
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
@@ -431,7 +432,8 @@ class Batch:
|
|||||||
|
|
||||||
def retract_decode(self):
|
def retract_decode(self):
|
||||||
sorted_indices = [i for i in range(len(self.reqs))]
|
sorted_indices = [i for i in range(len(self.reqs))]
|
||||||
# TODO(lsyin): improve the priority of retraction
|
|
||||||
|
# TODO(lsyin): improve retraction policy for radix cache
|
||||||
sorted_indices.sort(
|
sorted_indices.sort(
|
||||||
key=lambda i: (
|
key=lambda i: (
|
||||||
len(self.reqs[i].output_ids),
|
len(self.reqs[i].output_ids),
|
||||||
@@ -443,7 +445,17 @@ class Batch:
|
|||||||
retracted_reqs = []
|
retracted_reqs = []
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||||
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
while (
|
||||||
|
self.token_to_kv_pool.available_size()
|
||||||
|
< len(sorted_indices) * global_config.retract_decode_steps
|
||||||
|
):
|
||||||
|
if len(sorted_indices) == 1:
|
||||||
|
# Corner case: only one request left
|
||||||
|
assert (
|
||||||
|
self.token_to_kv_pool.available_size() > 0
|
||||||
|
), "No space left for only one request"
|
||||||
|
break
|
||||||
|
|
||||||
idx = sorted_indices.pop()
|
idx = sorted_indices.pop()
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
retracted_reqs.append(req)
|
retracted_reqs.append(req)
|
||||||
@@ -468,7 +480,16 @@ class Batch:
|
|||||||
|
|
||||||
self.filter_batch(sorted_indices)
|
self.filter_batch(sorted_indices)
|
||||||
|
|
||||||
return retracted_reqs
|
# Reqs in batch are filtered
|
||||||
|
total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
|
||||||
|
total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
|
||||||
|
|
||||||
|
new_estimate_ratio = (
|
||||||
|
total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
|
||||||
|
) / total_max_new_tokens
|
||||||
|
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||||
|
|
||||||
|
return retracted_reqs, new_estimate_ratio
|
||||||
|
|
||||||
def check_for_jump_forward(self, model_runner):
|
def check_for_jump_forward(self, model_runner):
|
||||||
jump_forward_reqs = []
|
jump_forward_reqs = []
|
||||||
|
|||||||
@@ -228,6 +228,7 @@ class ModelTpServer:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
|
self.new_token_ratio = global_config.init_new_token_ratio
|
||||||
|
|
||||||
def print_stats(self):
|
def print_stats(self):
|
||||||
num_used = self.max_total_num_tokens - (
|
num_used = self.max_total_num_tokens - (
|
||||||
@@ -536,9 +537,10 @@ class ModelTpServer:
|
|||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem():
|
if not batch.check_decode_mem():
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
|
||||||
|
|
||||||
retracted_reqs = batch.retract_decode()
|
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||||
|
self.new_token_ratio = new_token_ratio
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"decode out of memory happened, "
|
"decode out of memory happened, "
|
||||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||||
|
|||||||
Reference in New Issue
Block a user