Fix retract for page size > 1 (#4914)
This commit is contained in:
46
.github/workflows/pr-test.yml
vendored
46
.github/workflows/pr-test.yml
vendored
@@ -87,53 +87,11 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Test data parallelism (DP=2)
|
- name: Run test
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_data_parallelism.py
|
python3 run_suite.py --suite per-commit-2-gpu
|
||||||
|
|
||||||
- name: Test data parallelism attention (DP=2)
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_dp_attention.py
|
|
||||||
|
|
||||||
- name: Test update weights from distributed
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_update_weights_from_distributed.py
|
|
||||||
|
|
||||||
- name: Test VerlEngine
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_verl_engine.py
|
|
||||||
|
|
||||||
- name: Test Patch Torch
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_patch_torch.py
|
|
||||||
|
|
||||||
- name: Test expert parallelism (EP=2)
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_moe_ep.py
|
|
||||||
|
|
||||||
- name: Test torch compile (TP=2)
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt
|
|
||||||
python3 test_mla_tp.py
|
|
||||||
|
|
||||||
- name: Test lora tensor parallelism (TP=2)
|
|
||||||
timeout-minutes: 10
|
|
||||||
run: |
|
|
||||||
cd test/srt/models/lora
|
|
||||||
python3 test_lora_tp.py
|
|
||||||
|
|
||||||
performance-test-1-gpu-part-1:
|
performance-test-1-gpu-part-1:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
|
|||||||
@@ -169,7 +169,9 @@ class BaseGrammarBackend(ABC):
|
|||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
|
def create_grammar_backend(
|
||||||
|
server_args: ServerArgs, tokenizer, vocab_size: int
|
||||||
|
) -> Optional[BaseGrammarBackend]:
|
||||||
if server_args.grammar_backend == "outlines":
|
if server_args.grammar_backend == "outlines":
|
||||||
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
||||||
|
|
||||||
@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
)
|
)
|
||||||
|
elif server_args.grammar_backend == "none":
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
||||||
|
|
||||||
|
|||||||
@@ -599,6 +599,7 @@ class Req:
|
|||||||
self.extend_logprob_start_len = 0
|
self.extend_logprob_start_len = 0
|
||||||
self.is_chunked = 0
|
self.is_chunked = 0
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
|
self.already_computed = 0
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# If req.input_embeds is already a list, append its content directly
|
# If req.input_embeds is already a list, append its content directly
|
||||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||||
|
|
||||||
if req.is_retracted:
|
|
||||||
req.already_computed = 0
|
|
||||||
req.cached_tokens += pre_len - req.already_computed
|
req.cached_tokens += pre_len - req.already_computed
|
||||||
req.already_computed = seq_len
|
req.already_computed = seq_len
|
||||||
req.is_retracted = False
|
req.is_retracted = False
|
||||||
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
else:
|
else:
|
||||||
# TODO: apply more fine-grained retraction
|
# TODO: apply more fine-grained retraction
|
||||||
last_uncached_pos = len(req.prefix_indices)
|
last_uncached_pos = (
|
||||||
|
(len(req.prefix_indices) + server_args.page_size - 1)
|
||||||
|
// server_args.page_size
|
||||||
|
* server_args.page_size
|
||||||
|
)
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
|
|||||||
|
|
||||||
def __init__(self, labels: Dict[str, str]) -> None:
|
def __init__(self, labels: Dict[str, str]) -> None:
|
||||||
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge, Histogram
|
||||||
|
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
self.last_log_time = time.time()
|
self.last_log_time = time.time()
|
||||||
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
0.1,
|
0.1,
|
||||||
0.3,
|
0.2,
|
||||||
0.5,
|
0.4,
|
||||||
0.7,
|
0.6,
|
||||||
0.9,
|
0.8,
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
4,
|
4,
|
||||||
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
|
|||||||
40,
|
40,
|
||||||
60,
|
60,
|
||||||
80,
|
80,
|
||||||
120,
|
100,
|
||||||
160,
|
200,
|
||||||
],
|
400,
|
||||||
)
|
|
||||||
|
|
||||||
self.histogram_time_per_output_token = Histogram(
|
|
||||||
name="sglang:time_per_output_token_seconds",
|
|
||||||
documentation="Histogram of time per output token in seconds.",
|
|
||||||
labelnames=labels.keys(),
|
|
||||||
buckets=[
|
|
||||||
0.002,
|
|
||||||
0.005,
|
|
||||||
0.010,
|
|
||||||
0.020,
|
|
||||||
0.030,
|
|
||||||
0.040,
|
|
||||||
0.050,
|
|
||||||
0.060,
|
|
||||||
0.070,
|
|
||||||
0.080,
|
|
||||||
0.090,
|
|
||||||
0.100,
|
|
||||||
0.150,
|
|
||||||
0.200,
|
|
||||||
0.300,
|
|
||||||
0.400,
|
|
||||||
0.600,
|
|
||||||
0.800,
|
|
||||||
1.000,
|
|
||||||
2.000,
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
|
|||||||
0.030,
|
0.030,
|
||||||
0.035,
|
0.035,
|
||||||
0.040,
|
0.040,
|
||||||
0.050,
|
0.060,
|
||||||
0.075,
|
0.080,
|
||||||
0.100,
|
0.100,
|
||||||
0.150,
|
|
||||||
0.200,
|
0.200,
|
||||||
0.300,
|
|
||||||
0.400,
|
0.400,
|
||||||
0.500,
|
0.600,
|
||||||
0.750,
|
0.800,
|
||||||
1.000,
|
1.000,
|
||||||
2.000,
|
2.000,
|
||||||
|
4.000,
|
||||||
|
6.000,
|
||||||
|
8.000,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
|
|||||||
0.1,
|
0.1,
|
||||||
0.2,
|
0.2,
|
||||||
0.4,
|
0.4,
|
||||||
|
0.6,
|
||||||
0.8,
|
0.8,
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
5,
|
4,
|
||||||
|
6,
|
||||||
|
8,
|
||||||
10,
|
10,
|
||||||
20,
|
20,
|
||||||
40,
|
40,
|
||||||
60,
|
60,
|
||||||
80,
|
80,
|
||||||
100,
|
100,
|
||||||
150,
|
|
||||||
200,
|
200,
|
||||||
250,
|
400,
|
||||||
300,
|
800,
|
||||||
350,
|
|
||||||
500,
|
|
||||||
1000,
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
|
|||||||
):
|
):
|
||||||
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
||||||
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
||||||
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
|
if cached_tokens > 0:
|
||||||
|
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
|
||||||
self.num_requests_total.labels(**self.labels).inc(1)
|
self.num_requests_total.labels(**self.labels).inc(1)
|
||||||
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
||||||
if generation_tokens >= 1:
|
|
||||||
self.histogram_time_per_output_token.labels(**self.labels).observe(
|
|
||||||
e2e_latency / generation_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
def observe_time_to_first_token(self, value: float):
|
def observe_time_to_first_token(self, value: float):
|
||||||
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class ServerArgs:
|
|||||||
# Kernel backend
|
# Kernel backend
|
||||||
attention_backend: Optional[str] = None
|
attention_backend: Optional[str] = None
|
||||||
sampling_backend: Optional[str] = None
|
sampling_backend: Optional[str] = None
|
||||||
grammar_backend: Optional[str] = "xgrammar"
|
grammar_backend: Optional[str] = None
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
speculative_algorithm: Optional[str] = None
|
speculative_algorithm: Optional[str] = None
|
||||||
@@ -193,6 +193,13 @@ class ServerArgs:
|
|||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Expert parallelism
|
||||||
|
if self.enable_ep_moe:
|
||||||
|
self.ep_size = self.tp_size
|
||||||
|
logger.info(
|
||||||
|
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||||
|
)
|
||||||
|
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
@@ -274,12 +281,9 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
|
|
||||||
# Expert parallelism
|
# Choose grammar backend
|
||||||
if self.enable_ep_moe:
|
if self.grammar_backend is None:
|
||||||
self.ep_size = self.tp_size
|
self.grammar_backend = "xgrammar"
|
||||||
logger.info(
|
|
||||||
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Data parallelism attention
|
# Data parallelism attention
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
@@ -813,7 +817,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grammar-backend",
|
"--grammar-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["xgrammar", "outlines", "llguidance"],
|
choices=["xgrammar", "outlines", "llguidance", "none"],
|
||||||
default=ServerArgs.grammar_backend,
|
default=ServerArgs.grammar_backend,
|
||||||
help="Choose the backend for grammar-guided decoding.",
|
help="Choose the backend for grammar-guided decoding.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
|||||||
|
|
||||||
|
|
||||||
class CustomTestCase(unittest.TestCase):
|
class CustomTestCase(unittest.TestCase):
|
||||||
pass
|
|
||||||
|
|
||||||
"""
|
|
||||||
def _callTestMethod(self, method):
|
def _callTestMethod(self, method):
|
||||||
max_retry = int(
|
max_retry = int(
|
||||||
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
|
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
|
||||||
@@ -1023,4 +1020,3 @@ class CustomTestCase(unittest.TestCase):
|
|||||||
lambda: super(CustomTestCase, self)._callTestMethod(method),
|
lambda: super(CustomTestCase, self)._callTestMethod(method),
|
||||||
max_retry=max_retry,
|
max_retry=max_retry,
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ CI_LORA_MODELS = [
|
|||||||
],
|
],
|
||||||
max_loras_per_batch=1,
|
max_loras_per_batch=1,
|
||||||
),
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
ALL_OTHER_LORA_MODELS = [
|
||||||
LoRAModelCase(
|
LoRAModelCase(
|
||||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
adaptors=[
|
adaptors=[
|
||||||
@@ -43,9 +46,6 @@ CI_LORA_MODELS = [
|
|||||||
],
|
],
|
||||||
max_loras_per_batch=1,
|
max_loras_per_batch=1,
|
||||||
),
|
),
|
||||||
]
|
|
||||||
|
|
||||||
ALL_OTHER_LORA_MODELS = [
|
|
||||||
LoRAModelCase(
|
LoRAModelCase(
|
||||||
base="meta-llama/Llama-2-7b-hf",
|
base="meta-llama/Llama-2-7b-hf",
|
||||||
adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
|
adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ suites = {
|
|||||||
TestFile("models/lora/test_lora.py", 76),
|
TestFile("models/lora/test_lora.py", 76),
|
||||||
TestFile("models/lora/test_lora_backend.py", 420),
|
TestFile("models/lora/test_lora_backend.py", 420),
|
||||||
TestFile("models/lora/test_multi_lora_backend.py", 144),
|
TestFile("models/lora/test_multi_lora_backend.py", 144),
|
||||||
TestFile("models/test_embedding_models.py", 119),
|
TestFile("models/test_embedding_models.py", 35),
|
||||||
TestFile("models/test_generation_models.py", 103),
|
TestFile("models/test_generation_models.py", 103),
|
||||||
TestFile("models/test_grok_models.py", 60),
|
TestFile("models/test_grok_models.py", 60),
|
||||||
TestFile("models/test_qwen_models.py", 82),
|
TestFile("models/test_qwen_models.py", 82),
|
||||||
@@ -38,7 +38,7 @@ suites = {
|
|||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_mla.py", 92),
|
TestFile("test_mla.py", 92),
|
||||||
TestFile("test_mla_deepseek_v3.py", 221),
|
TestFile("test_mla_deepseek_v3.py", 221),
|
||||||
TestFile("test_mla_int8_deepseek_v3.py", 421),
|
TestFile("test_mla_int8_deepseek_v3.py", 522),
|
||||||
TestFile("test_mla_flashinfer.py", 395),
|
TestFile("test_mla_flashinfer.py", 395),
|
||||||
TestFile("test_mla_fp8.py", 93),
|
TestFile("test_mla_fp8.py", 93),
|
||||||
TestFile("test_no_chunked_prefill.py", 126),
|
TestFile("test_no_chunked_prefill.py", 126),
|
||||||
@@ -59,7 +59,7 @@ suites = {
|
|||||||
TestFile("test_srt_endpoint.py", 94),
|
TestFile("test_srt_endpoint.py", 94),
|
||||||
TestFile("test_torch_compile.py", 76),
|
TestFile("test_torch_compile.py", 76),
|
||||||
TestFile("test_torch_compile_moe.py", 85),
|
TestFile("test_torch_compile_moe.py", 85),
|
||||||
TestFile("test_torch_native_attention_backend.py", 149),
|
TestFile("test_torch_native_attention_backend.py", 123),
|
||||||
TestFile("test_torchao.py", 70),
|
TestFile("test_torchao.py", 70),
|
||||||
TestFile("test_triton_attention_kernels.py", 4),
|
TestFile("test_triton_attention_kernels.py", 4),
|
||||||
TestFile("test_triton_attention_backend.py", 134),
|
TestFile("test_triton_attention_backend.py", 134),
|
||||||
@@ -76,6 +76,16 @@ suites = {
|
|||||||
TestFile("test_hicache.py", 60),
|
TestFile("test_hicache.py", 60),
|
||||||
TestFile("test_hicache_mla.py", 90),
|
TestFile("test_hicache_mla.py", 90),
|
||||||
],
|
],
|
||||||
|
"per-commit-2-gpu": [
|
||||||
|
TestFile("test_data_parallelism.py", 90),
|
||||||
|
TestFile("test_dp_attention.py", 90),
|
||||||
|
TestFile("test_update_weights_from_distributed.py", 100),
|
||||||
|
TestFile("test_verl_engine.py", 100),
|
||||||
|
TestFile("test_patch_torch.py", 30),
|
||||||
|
TestFile("test_moe_ep.py", 220),
|
||||||
|
TestFile("test_mla_tp.py", 420),
|
||||||
|
TestFile("test_lora_tp.py", 300),
|
||||||
|
],
|
||||||
"nightly": [
|
"nightly": [
|
||||||
TestFile("test_nightly_gsm8k_eval.py"),
|
TestFile("test_nightly_gsm8k_eval.py"),
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -60,3 +60,7 @@ class TestDPAttentionDP2TP2(CustomTestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
print(f"{metrics=}")
|
print(f"{metrics=}")
|
||||||
self.assertGreater(metrics["score"], 0.8)
|
self.assertGreater(metrics["score"], 0.8)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|||||||
@@ -63,7 +63,6 @@ class TestEnableMetrics(CustomTestCase):
|
|||||||
"sglang:cached_tokens_total",
|
"sglang:cached_tokens_total",
|
||||||
"sglang:num_requests_total",
|
"sglang:num_requests_total",
|
||||||
"sglang:time_to_first_token_seconds",
|
"sglang:time_to_first_token_seconds",
|
||||||
"sglang:time_per_output_token_seconds",
|
|
||||||
"sglang:inter_token_latency_seconds",
|
"sglang:inter_token_latency_seconds",
|
||||||
"sglang:e2e_request_latency_seconds",
|
"sglang:e2e_request_latency_seconds",
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user