Simplify batch update (#2154)
This commit is contained in:
@@ -35,7 +35,7 @@ SGLang is a fast serving framework for large language models and vision language
|
|||||||
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
|
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
|
||||||
The core features include:
|
The core features include:
|
||||||
|
|
||||||
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
|
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
|
||||||
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
|
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
|
||||||
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
|
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
|
||||||
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
|
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
|
||||||
|
|||||||
@@ -79,7 +79,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
```
|
```
|
||||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
|
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
|
||||||
```
|
```
|
||||||
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
|
|
||||||
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
|
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
|
||||||
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
|
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
|
||||||
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
|
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ SGLang is a fast serving framework for large language models and vision language
|
|||||||
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
|
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
|
||||||
The core features include:
|
The core features include:
|
||||||
|
|
||||||
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
|
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
|
||||||
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
|
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
|
||||||
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
|
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
|
||||||
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
|
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ If you see out of memory (OOM) errors, you can try to tune the following paramet
|
|||||||
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
|
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
|
||||||
|
|
||||||
### Try Advanced Options
|
### Try Advanced Options
|
||||||
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
|
|
||||||
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
|
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
|
||||||
|
|
||||||
### Tune `--schedule-policy`
|
### Tune `--schedule-policy`
|
||||||
|
|||||||
@@ -467,6 +467,7 @@ class ScheduleBatch:
|
|||||||
extend_lens: List[int] = None
|
extend_lens: List[int] = None
|
||||||
extend_num_tokens: int = None
|
extend_num_tokens: int = None
|
||||||
decoding_reqs: List[Req] = None
|
decoding_reqs: List[Req] = None
|
||||||
|
extend_logprob_start_lens: List[int] = None
|
||||||
|
|
||||||
# For encoder-decoder
|
# For encoder-decoder
|
||||||
encoder_cached: Optional[List[bool]] = None
|
encoder_cached: Optional[List[bool]] = None
|
||||||
@@ -722,7 +723,6 @@ class ScheduleBatch:
|
|||||||
self.merge_batch(running_batch)
|
self.merge_batch(running_batch)
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.extend_num_tokens += running_bs
|
|
||||||
|
|
||||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||||
self.prefix_lens.extend(
|
self.prefix_lens.extend(
|
||||||
@@ -732,6 +732,8 @@ class ScheduleBatch:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.extend_lens.extend([1] * running_bs)
|
self.extend_lens.extend([1] * running_bs)
|
||||||
|
self.extend_num_tokens += running_bs
|
||||||
|
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
||||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@@ -28,7 +27,7 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -302,6 +301,9 @@ class Scheduler:
|
|||||||
) / global_config.default_new_token_ratio_decay_steps
|
) / global_config.default_new_token_ratio_decay_steps
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
|
# Tells whether the current running batch is full so that we can skip
|
||||||
|
# the check of whether to prefill new requests.
|
||||||
|
# This is an optimization to reduce the overhead of the prefill check.
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
# Init watchdog thread
|
# Init watchdog thread
|
||||||
@@ -721,40 +723,30 @@ class Scheduler:
|
|||||||
|
|
||||||
def get_next_batch_to_run(self):
|
def get_next_batch_to_run(self):
|
||||||
# Merge the prefill batch into the running batch
|
# Merge the prefill batch into the running batch
|
||||||
if (
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||||
self.last_batch
|
|
||||||
and not self.last_batch.forward_mode.is_decode()
|
|
||||||
and not self.last_batch.is_empty()
|
|
||||||
):
|
|
||||||
if self.being_chunked_req:
|
if self.being_chunked_req:
|
||||||
|
# Move the chunked request out of the batch
|
||||||
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
|
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
|
||||||
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
||||||
# Inflight request keeps its rid but will get a new req_pool_idx.
|
# Inflight request keeps its rid but will get a new req_pool_idx
|
||||||
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
if not self.last_batch.is_empty():
|
if not self.last_batch.is_empty():
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
self.running_batch = self.last_batch
|
self.running_batch = self.last_batch
|
||||||
else:
|
else:
|
||||||
self.running_batch.merge_batch(self.last_batch)
|
self.running_batch.merge_batch(self.last_batch)
|
||||||
|
|
||||||
# Prefill first
|
# Run prefill first if possible
|
||||||
new_batch = self.get_new_batch_prefill()
|
new_batch = self.get_new_batch_prefill()
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
# Check memory
|
|
||||||
if self.running_batch is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Run decode
|
# Run decode
|
||||||
before_bs = self.running_batch.batch_size()
|
if self.running_batch is None:
|
||||||
self.update_running_batch()
|
|
||||||
if not self.running_batch:
|
|
||||||
self.batch_is_full = False
|
|
||||||
return None
|
return None
|
||||||
if before_bs != self.running_batch.batch_size():
|
self.running_batch = self.update_running_batch(self.running_batch)
|
||||||
self.batch_is_full = False
|
|
||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
@@ -866,15 +858,16 @@ class Scheduler:
|
|||||||
|
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
def update_running_batch(self):
|
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
|
||||||
"""Update the current running decoding batch."""
|
"""Update the current running decoding batch."""
|
||||||
global test_retract
|
global test_retract
|
||||||
batch = self.running_batch
|
|
||||||
|
initial_bs = batch.batch_size()
|
||||||
|
|
||||||
batch.filter_batch()
|
batch.filter_batch()
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
self.running_batch = None
|
self.batch_is_full = False
|
||||||
return
|
return None
|
||||||
|
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
|
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
|
||||||
@@ -900,11 +893,15 @@ class Scheduler:
|
|||||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||||
self.waiting_queue.extend(jump_forward_reqs)
|
self.waiting_queue.extend(jump_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
self.running_batch = None
|
self.batch_is_full = False
|
||||||
return
|
return None
|
||||||
|
|
||||||
|
if batch.batch_size() < initial_bs:
|
||||||
|
self.batch_is_full = False
|
||||||
|
|
||||||
# Update batch tensors
|
# Update batch tensors
|
||||||
batch.prepare_for_decode(self.enable_overlap)
|
batch.prepare_for_decode(self.enable_overlap)
|
||||||
|
return batch
|
||||||
|
|
||||||
def run_batch(self, batch: ScheduleBatch):
|
def run_batch(self, batch: ScheduleBatch):
|
||||||
"""Run a batch."""
|
"""Run a batch."""
|
||||||
@@ -979,8 +976,10 @@ class Scheduler:
|
|||||||
if req.is_retracted:
|
if req.is_retracted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||||
|
raise ValueError("Unhandled error!")
|
||||||
|
|
||||||
if req.is_being_chunked <= 0:
|
if req.is_being_chunked <= 0:
|
||||||
# Inflight reqs' prefill is not finished
|
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
@@ -990,14 +989,16 @@ class Scheduler:
|
|||||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.grammar is not None:
|
|
||||||
req.grammar.accept_token(next_token_id)
|
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
|
# TODO (lianmin): need to think the case w/ mixed chunked prefill
|
||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if req.grammar is not None:
|
||||||
|
req.grammar.accept_token(next_token_id)
|
||||||
else:
|
else:
|
||||||
|
# Inflight reqs' prefill is not finished
|
||||||
req.is_being_chunked -= 1
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
if batch.next_batch_sampling_info:
|
if batch.next_batch_sampling_info:
|
||||||
@@ -1015,18 +1016,18 @@ class Scheduler:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req.is_being_chunked > 0:
|
if req.is_being_chunked <= 0:
|
||||||
req.is_being_chunked -= 1
|
# Dummy output token for embedding models
|
||||||
else:
|
|
||||||
# Inflight reqs' prefill is not finished
|
|
||||||
# dummy output token for embedding models
|
|
||||||
req.output_ids.append(0)
|
req.output_ids.append(0)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
else:
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
# Inflight reqs' prefill is not finished
|
||||||
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
self.stream_output(batch.reqs)
|
self.stream_output(batch.reqs)
|
||||||
|
|
||||||
@@ -1061,9 +1062,6 @@ class Scheduler:
|
|||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.grammar is not None:
|
|
||||||
req.grammar.accept_token(next_token_id)
|
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
@@ -1074,6 +1072,9 @@ class Scheduler:
|
|||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
|
if req.grammar is not None:
|
||||||
|
req.grammar.accept_token(next_token_id)
|
||||||
|
|
||||||
if batch.next_batch_sampling_info:
|
if batch.next_batch_sampling_info:
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.cuda.current_stream().synchronize()
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ def run_eval(args):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--num-shots", type=int, default=5)
|
parser.add_argument("--num-shots", type=int, default=5)
|
||||||
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
parser.add_argument("--data-path", type=str)
|
||||||
parser.add_argument("--num-questions", type=int, default=200)
|
parser.add_argument("--num-questions", type=int, default=200)
|
||||||
parser.add_argument("--max-new-tokens", type=int, default=512)
|
parser.add_argument("--max-new-tokens", type=int, default=512)
|
||||||
parser.add_argument("--parallel", type=int, default=128)
|
parser.add_argument("--parallel", type=int, default=128)
|
||||||
|
|||||||
Reference in New Issue
Block a user