Mixed style of chunked prefill (#1013)

This commit is contained in:
Liangsheng Yin
2024-08-16 02:13:00 -07:00
committed by GitHub
parent 5a261bd055
commit 3694f8f996
14 changed files with 195 additions and 59 deletions

View File

@@ -111,11 +111,14 @@ class PrefillAdder:
rem_total_tokens: int, rem_total_tokens: int,
rem_input_tokens: int, rem_input_tokens: int,
rem_chunk_tokens: Optional[int], rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
): ):
self.tree_cache = tree_cache self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_input_tokens = rem_input_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens
self.can_run_list = [] self.can_run_list = []
self.new_inflight_req = None self.new_inflight_req = None

View File

@@ -329,6 +329,9 @@ class ScheduleBatch:
out_cache_loc: torch.Tensor = None out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None extend_num_tokens: int = None
# For mixed chunekd prefill
prefix_lens_cpu: List[int] = None
# For processing logprobs # For processing logprobs
return_logprob: bool = False return_logprob: bool = False
top_logprobs_nums: List[int] = None top_logprobs_nums: List[int] = None
@@ -462,9 +465,33 @@ class ScheduleBatch:
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.batch_sampling_params(vocab_size) self.batch_sampling_params(vocab_size)
def mix_with_running(self, running_batch: "ScheduleBatch"):
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
prefix_lens_cpu.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
for req in running_batch.reqs:
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
self.merge(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
self.prefix_lens_cpu = prefix_lens_cpu
def check_decode_mem(self): def check_decode_mem(self):
bs = self.batch_size() bs = self.batch_size()
if self.token_to_kv_pool.available_size() >= bs: if self.token_to_kv_pool.available_size() >= bs:

View File

@@ -174,6 +174,9 @@ class ModelTpServer:
# Chunked prefill # Chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None self.current_inflight_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the FSM cache for constrained generation # Init the FSM cache for constrained generation
if not server_args.skip_tokenizer_init: if not server_args.skip_tokenizer_init:
@@ -366,11 +369,14 @@ class ModelTpServer:
# Get priority queue # Get priority queue
prefix_computed = self.scheduler.calc_priority(self.waiting_queue) prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder( adder = PrefillAdder(
self.tree_cache, self.tree_cache,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens, self.max_prefill_tokens,
self.chunked_prefill_size, self.chunked_prefill_size,
num_mixed_running,
) )
if self.running_batch is not None: if self.running_batch is not None:
@@ -416,15 +422,27 @@ class ModelTpServer:
) )
else: else:
tree_cache_hit_rate = 0.0 tree_cache_hit_rate = 0.0
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. " if num_mixed_running > 0:
f"#new-seq: {len(can_run_list)}, " logger.info(
f"#new-token: {adder.log_input_tokens}, " f"[gpu={self.gpu_id}] Prefill batch"
f"#cached-token: {adder.log_hit_tokens}, " f"(mixed #running-req: {num_mixed_running}). "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"#new-seq: {len(can_run_list)}, "
f"#running-req: {running_bs}, " f"#new-token: {adder.log_input_tokens}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" f"#cached-token: {adder.log_hit_tokens}, "
) f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
else:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
# Return the new batch # Return the new batch
new_batch = ScheduleBatch.init_new( new_batch = ScheduleBatch.init_new(
@@ -440,6 +458,13 @@ class ModelTpServer:
# Build batch tensors # Build batch tensors
batch.prepare_for_extend(self.model_config.vocab_size) batch.prepare_for_extend(self.model_config.vocab_size)
decoding_reqs = []
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode()
batch.mix_with_running(self.running_batch)
decoding_reqs = self.running_batch.reqs
self.running_batch = None
if self.model_runner.is_generation: if self.model_runner.is_generation:
# Forward and sample the next tokens # Forward and sample the next tokens
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
@@ -481,7 +506,8 @@ class ModelTpServer:
if req.finished(): if req.finished():
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
else: elif req not in decoding_reqs:
# To reduce overhead, only cache prefill reqs
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req: if req is self.current_inflight_req:

View File

@@ -88,11 +88,11 @@ class InputMetadata:
self.image_sizes = [r.image_size for r in reqs] self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [ self.image_offsets = [
( (
(r.image_offset - len(r.prefix_indices)) (r.image_offset - batch.prefix_lens_cpu[i])
if r.image_offset is not None if r.image_offset is not None
else 0 else 0
) )
for r in reqs for i, r in enumerate(reqs)
] ]
def compute_positions(self, batch: ScheduleBatch): def compute_positions(self, batch: ScheduleBatch):
@@ -109,8 +109,8 @@ class InputMetadata:
self.positions = torch.tensor( self.positions = torch.tensor(
np.concatenate( np.concatenate(
[ [
np.arange(len(req.prefix_indices), len(req.fill_ids)) np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
for req in batch.reqs for i, req in enumerate(batch.reqs)
], ],
axis=0, axis=0,
), ),
@@ -123,7 +123,7 @@ class InputMetadata:
np.concatenate( np.concatenate(
[ [
np.arange( np.arange(
len(req.prefix_indices) + position_ids_offsets_cpu[i], batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
len(req.fill_ids) + position_ids_offsets_cpu[i], len(req.fill_ids) + position_ids_offsets_cpu[i],
) )
for i, req in enumerate(batch.reqs) for i, req in enumerate(batch.reqs)
@@ -141,12 +141,13 @@ class InputMetadata:
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
else: else:
extend_lens_cpu = [ extend_lens_cpu = [
len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs len(r.fill_ids) - batch.prefix_lens_cpu[i]
for i, r in enumerate(batch.reqs)
] ]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs) self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@classmethod @classmethod
def from_schedule_batch( def from_schedule_batch(
@@ -180,14 +181,8 @@ class InputMetadata:
if forward_mode != ForwardMode.DECODE: if forward_mode != ForwardMode.DECODE:
ret.init_multimuldal_info(batch) ret.init_multimuldal_info(batch)
prefix_lens = None
if forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(
[len(r.prefix_indices) for r in batch.reqs], device="cuda"
)
if model_runner.server_args.disable_flashinfer: if model_runner.server_args.disable_flashinfer:
ret.init_triton_args(batch, prefix_lens) ret.init_triton_args(batch)
flashinfer_use_ragged = False flashinfer_use_ragged = False
if not model_runner.server_args.disable_flashinfer: if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +193,35 @@ class InputMetadata:
): ):
flashinfer_use_ragged = True flashinfer_use_ragged = True
ret.init_flashinfer_handlers( ret.init_flashinfer_handlers(
model_runner, prefix_lens, flashinfer_use_ragged model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
) )
return ret return ret
def init_triton_args(self, batch: ScheduleBatch, prefix_lens): def init_triton_args(self, batch: ScheduleBatch):
"""Init auxiliary variables for triton attention backend.""" """Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = int(torch.max(self.seq_lens)) self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_prefix_lens = prefix_lens
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode == ForwardMode.DECODE: if self.forward_mode == ForwardMode.DECODE:
self.triton_max_extend_len = None self.triton_max_extend_len = None
else: else:
extend_seq_lens = self.seq_lens - prefix_lens self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
self.triton_max_extend_len = int(torch.max(extend_seq_lens)) self.triton_max_extend_len = int(torch.max(extend_seq_lens))
def init_flashinfer_handlers( def init_flashinfer_handlers(
self, self,
model_runner, model_runner,
prefix_lens, prefix_lens_cpu,
flashinfer_use_ragged, flashinfer_use_ragged,
): ):
if self.forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
else:
prefix_lens = None
update_flashinfer_indices( update_flashinfer_indices(
self.forward_mode, self.forward_mode,
model_runner, model_runner,

View File

@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
print(f"Initialization failed. warmup error: {last_traceback}", flush=True) print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1) sys.exit(1)
# Print warnings here
if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
logger.warning(
"You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
"This combination is an experimental feature and we noticed it can lead to "
"wrong generation results. If you want to use chunked prefill, it is recommended "
"not using `--disable-radix-cache`."
)
logger.info("The server is fired up and ready to roll!") logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok") pipe_finish_writer.send("init ok")

View File

@@ -80,6 +80,7 @@ class ServerArgs:
disable_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False disable_cuda_graph: bool = False
disable_disk_cache: bool = False disable_disk_cache: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False enable_torch_compile: bool = False
enable_p2p_check: bool = False enable_p2p_check: bool = False
enable_mla: bool = False enable_mla: bool = False
@@ -396,6 +397,11 @@ class ServerArgs:
action="store_true", action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
) )
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a chunked batch.",
)
parser.add_argument( parser.add_argument(
"--enable-torch-compile", "--enable-torch-compile",
action="store_true", action="store_true",

View File

@@ -1,13 +1,12 @@
# Adapted from https://github.com/openai/simple-evals/ # Adapted from https://github.com/openai/simple-evals/
import base64
import os import os
import resource import resource
import time import time
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import httpx import httpx
import jinja2 import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
Result of running an evaluation (usually consisting of many samples) Result of running an evaluation (usually consisting of many samples)
""" """
score: float | None # top-line metric score: Optional[float] # top-line metric
metrics: Dict[str, float] | None # other metrics metrics: Optional[Dict[str, float]] # other metrics
htmls: List[str] # strings of valid HTML htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations convos: List[MessageList] # sampled conversations
@@ -56,10 +55,10 @@ class SingleEvalResult:
Result of evaluating a single sample Result of evaluating a single sample
""" """
score: float | None score: Optional[float]
metrics: Dict[str, float] = field(default_factory=dict) metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None html: Optional[str] = None
convo: MessageList | None = None # sampled conversation convo: Optional[MessageList] = None # sampled conversation
class Eval: class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
def __init__( def __init__(
self, self,
base_url: str = None, base_url: str = None,
model: str | None = None, model: Optional[str] = None,
system_message: str | None = None, system_message: Optional[str] = None,
temperature: float = 0.0, temperature: float = 0.0,
max_tokens: int = 2048, max_tokens: int = 2048,
): ):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
def aggregate_results( def aggregate_results(
single_eval_results: List[SingleEvalResult], single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"), default_stats: Tuple[str] = ("mean", "std"),
name2stats: Dict[str, Tuple[str]] | None = None, name2stats: Optional[Dict[str, Tuple[str]]] = None,
) -> EvalResult: ) -> EvalResult:
""" """
Aggregate results from multiple evaluations into a single EvalResult. Aggregate results from multiple evaluations into a single EvalResult.

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
import random import random
import re import re
from typing import Optional
import pandas import pandas
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
def __init__( def __init__(
self, self,
filename: str, filename: str,
num_examples: int | None, num_examples: Optional[int],
num_threads: int, num_threads: int,
n_repeats: int = 1, n_repeats: int = 1,
): ):

View File

@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
import random import random
import re import re
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List from typing import Dict, List, Optional
import tqdm import tqdm
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
class HumanEval(Eval): class HumanEval(Eval):
def __init__( def __init__(
self, self,
num_examples: int | None, num_examples: Optional[int],
num_threads: int, num_threads: int,
num_samples_per_task: int = 5, num_samples_per_task: int = 5,
ks_passes: List[int] = [1, 2, 5], ks_passes: List[int] = [1, 2, 5],

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
import random import random
import re import re
from typing import Optional
import pandas import pandas
@@ -36,7 +37,7 @@ class MathEval(Eval):
self, self,
filename: str, filename: str,
equality_checker: SamplerBase, equality_checker: SamplerBase,
num_examples: int | None, num_examples: Optional[int],
num_threads: int, num_threads: int,
): ):
df = pandas.read_csv(filename) df = pandas.read_csv(filename)

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
import random import random
import re import re
from typing import Optional
import pandas import pandas
@@ -84,7 +85,7 @@ subject2category = {
class MMLUEval(Eval): class MMLUEval(Eval):
def __init__(self, filename: str, num_examples: int | None, num_threads: int): def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
df = pandas.read_csv(filename) df = pandas.read_csv(filename)
examples = [row.to_dict() for _, row in df.iterrows()] examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples: if num_examples:

View File

@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
class TestChunkedPrefill(unittest.TestCase): class TestChunkedPrefill(unittest.TestCase):
def run_mmlu(self, disable_radix_cache): def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
other_args = ["--chunked-prefill-size", "32"] other_args = ["--chunked-prefill-size", "32"]
if disable_radix_cache: if disable_radix_cache:
other_args += ["--disable-radix-cache"] other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
other_args += ["--enable-mixed-chunk"]
model = DEFAULT_MODEL_NAME_FOR_TEST model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_UNIT_TEST base_url = DEFAULT_URL_FOR_UNIT_TEST
process = popen_launch_server( process = popen_launch_server(
@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
kill_child_process(process.pid) kill_child_process(process.pid)
def test_chunked_prefill(self): def test_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False) self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
def test_mixed_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)
def test_chunked_prefill_without_radix_cache(self): def test_chunked_prefill_without_radix_cache(self):
self.run_mmlu(disable_radix_cache=True) self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
def test_mixed_chunked_prefill_without_radix_cache(self):
self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_ACCURACY_TEST, DEFAULT_URL_FOR_ACCURACY_TEST,
DEFAULT_URL_FOR_UNIT_TEST,
popen_launch_server, popen_launch_server,
) )

View File

@@ -0,0 +1,73 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_ACCURACY_TEST,
popen_launch_server,
)
class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=300,
other_args=[
"--log-level-http",
"warning",
"--chunked-prefill-size",
"256",
"--enable-mixed-chunk",
],
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=3000,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.71, f"{metrics}"
def test_human_eval(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="humaneval",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.64, f"{metrics}"
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.84, f"{metrics}"
if __name__ == "__main__":
unittest.main()