Standalone speculative decoding (#10090)
This commit is contained in:
@@ -1539,7 +1539,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
|
|
||||||
if self.spec_algorithm.is_eagle():
|
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||||
# if spec decoding is used, the decode batch is prepared inside
|
# if spec decoding is used, the decode batch is prepared inside
|
||||||
# `forward_batch_speculative_generation` after running draft models.
|
# `forward_batch_speculative_generation` after running draft models.
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -349,6 +349,18 @@ class Scheduler(
|
|||||||
target_worker=self.tp_worker,
|
target_worker=self.tp_worker,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
)
|
)
|
||||||
|
elif self.spec_algorithm.is_standalone():
|
||||||
|
from sglang.srt.speculative.standalone_worker import StandaloneWorker
|
||||||
|
|
||||||
|
self.draft_worker = StandaloneWorker(
|
||||||
|
gpu_id=gpu_id,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
moe_ep_rank=moe_ep_rank,
|
||||||
|
server_args=server_args,
|
||||||
|
nccl_port=port_args.nccl_port,
|
||||||
|
target_worker=self.tp_worker,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.draft_worker = None
|
self.draft_worker = None
|
||||||
|
|
||||||
|
|||||||
@@ -271,7 +271,10 @@ class CudaGraphRunner:
|
|||||||
self.capture_forward_mode = ForwardMode.DECODE
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
self.num_tokens_per_bs = 1
|
self.num_tokens_per_bs = 1
|
||||||
if model_runner.spec_algorithm.is_eagle():
|
if (
|
||||||
|
model_runner.spec_algorithm.is_eagle()
|
||||||
|
or model_runner.spec_algorithm.is_standalone()
|
||||||
|
):
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
raise RuntimeError("This should not happen")
|
raise RuntimeError("This should not happen")
|
||||||
else:
|
else:
|
||||||
@@ -827,7 +830,10 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def get_spec_info(self, num_tokens: int):
|
def get_spec_info(self, num_tokens: int):
|
||||||
spec_info = None
|
spec_info = None
|
||||||
if self.model_runner.spec_algorithm.is_eagle():
|
if (
|
||||||
|
self.model_runner.spec_algorithm.is_eagle()
|
||||||
|
or self.model_runner.spec_algorithm.is_standalone()
|
||||||
|
):
|
||||||
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||||
|
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
|
|||||||
@@ -473,9 +473,14 @@ class ServerArgs:
|
|||||||
# B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
|
# B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
|
||||||
reserved_mem = 32 * 1024
|
reserved_mem = 32 * 1024
|
||||||
|
|
||||||
|
# draft model and larger cuda graph buffers
|
||||||
if self.speculative_algorithm is not None:
|
if self.speculative_algorithm is not None:
|
||||||
# draft model and larger cuda graph buffers
|
if self.speculative_algorithm == "STANDALONE":
|
||||||
reserved_mem += 2 * 1024
|
# Standalone speculative decoding needs more memory than other speculative
|
||||||
|
# decoding algorithms since the draft model is typically larger.
|
||||||
|
reserved_mem += 6 * 1024
|
||||||
|
else:
|
||||||
|
reserved_mem += 2 * 1024
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
reserved_mem += 4 * 1024
|
reserved_mem += 4 * 1024
|
||||||
|
|
||||||
@@ -704,7 +709,12 @@ class ServerArgs:
|
|||||||
# NEXTN shares the same implementation of EAGLE
|
# NEXTN shares the same implementation of EAGLE
|
||||||
self.speculative_algorithm = "EAGLE"
|
self.speculative_algorithm = "EAGLE"
|
||||||
|
|
||||||
if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
|
if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
|
||||||
|
if self.speculative_algorithm == "STANDALONE":
|
||||||
|
# TODO: support dp attention for standalone speculative decoding
|
||||||
|
assert (
|
||||||
|
self.enable_dp_attention is False
|
||||||
|
), "Currently standalone speculative decoding does not support dp attention."
|
||||||
if self.max_running_requests is None:
|
if self.max_running_requests is None:
|
||||||
self.max_running_requests = 48
|
self.max_running_requests = 48
|
||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
@@ -1499,7 +1509,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-algorithm",
|
"--speculative-algorithm",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["EAGLE", "EAGLE3", "NEXTN"],
|
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
|
||||||
help="Speculative algorithm.",
|
help="Speculative algorithm.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -2635,7 +2645,9 @@ def auto_choose_speculative_params(self: ServerArgs):
|
|||||||
"""
|
"""
|
||||||
hf_config = self.get_hf_config()
|
hf_config = self.get_hf_config()
|
||||||
arch = hf_config.architectures[0]
|
arch = hf_config.architectures[0]
|
||||||
|
if self.speculative_algorithm == "STANDALONE":
|
||||||
|
# The default value for standalone speculative decoding
|
||||||
|
return (3, 1, 4)
|
||||||
if arch in ["LlamaForCausalLM"]:
|
if arch in ["LlamaForCausalLM"]:
|
||||||
# The default value for llama
|
# The default value for llama
|
||||||
return (5, 4, 8)
|
return (5, 4, 8)
|
||||||
|
|||||||
@@ -341,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
||||||
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
||||||
self.positions[:num_tokens].copy_(forward_batch.positions)
|
self.positions[:num_tokens].copy_(forward_batch.positions)
|
||||||
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
if (
|
||||||
|
forward_batch.spec_info.hidden_states.shape[1]
|
||||||
|
== self.hidden_states.shape[1]
|
||||||
|
):
|
||||||
|
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
||||||
if forward_batch.spec_info.accept_length is not None:
|
if forward_batch.spec_info.accept_length is not None:
|
||||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
|
|||||||
@@ -730,6 +730,14 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
# Set inputs
|
# Set inputs
|
||||||
forward_batch.input_ids = input_ids
|
forward_batch.input_ids = input_ids
|
||||||
|
# This is a temporary fix for the case that the user is using standalone
|
||||||
|
# speculative decoding and the draft model architecture is gpt-oss. gpt-oss
|
||||||
|
# rope kernel needs cache_loc to be contiguous.
|
||||||
|
if (
|
||||||
|
self.server_args.speculative_algorithm == "STANDALONE"
|
||||||
|
and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
|
||||||
|
):
|
||||||
|
out_cache_loc = out_cache_loc.contiguous()
|
||||||
forward_batch.out_cache_loc = out_cache_loc[i]
|
forward_batch.out_cache_loc = out_cache_loc[i]
|
||||||
forward_batch.positions.add_(1)
|
forward_batch.positions.add_(1)
|
||||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
NONE = auto()
|
NONE = auto()
|
||||||
EAGLE = auto()
|
EAGLE = auto()
|
||||||
EAGLE3 = auto()
|
EAGLE3 = auto()
|
||||||
|
STANDALONE = auto()
|
||||||
|
|
||||||
def is_none(self):
|
def is_none(self):
|
||||||
return self == SpeculativeAlgorithm.NONE
|
return self == SpeculativeAlgorithm.NONE
|
||||||
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
def is_eagle3(self):
|
def is_eagle3(self):
|
||||||
return self == SpeculativeAlgorithm.EAGLE3
|
return self == SpeculativeAlgorithm.EAGLE3
|
||||||
|
|
||||||
|
def is_standalone(self):
|
||||||
|
return self == SpeculativeAlgorithm.STANDALONE
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_string(name: str):
|
def from_string(name: str):
|
||||||
name_map = {
|
name_map = {
|
||||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||||
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
||||||
|
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
|
||||||
None: SpeculativeAlgorithm.NONE,
|
None: SpeculativeAlgorithm.NONE,
|
||||||
}
|
}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
|
|||||||
109
python/sglang/srt/speculative/standalone_worker.py
Normal file
109
python/sglang/srt/speculative/standalone_worker.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||||
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
|
||||||
|
|
||||||
|
if is_cuda():
|
||||||
|
from sgl_kernel import segment_packbits
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Run the enclosed code with ``tp_group`` patched in as the tensor-parallel group.

    The draft model doesn't use dp and has its own tp group, so its work is
    wrapped in this context.

    NOTE(review): the accompanying remark that mscclpp is disabled (because it
    does not support 2 comm groups) is not implemented by anything visible in
    this function; presumably ``patch_tensor_parallel_group`` takes care of
    it -- confirm.
    """
    patched_group = patch_tensor_parallel_group(tp_group)
    with patched_group:
        yield
|
||||||
|
|
||||||
|
|
||||||
|
class StandaloneWorker(EAGLEWorker):
    """Draft worker used when the speculative algorithm is ``STANDALONE``.

    Only the constructor is defined here; every other method comes from
    ``EAGLEWorker``. The constructor initialises the draft model through
    ``TpModelWorker.__init__`` directly, reusing the target worker's memory
    pool objects.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Build the draft worker next to an already constructed target worker.

        Note that ``server_args`` is mutated: ``context_length`` is
        overwritten, ``disable_cuda_graph`` is temporarily forced to ``True``
        and ``json_model_override_args`` is replaced when a token map is
        configured.
        """
        # Handles to the collaborating objects.
        self.server_args = server_args
        self.target_worker = target_worker
        self.gpu_id = gpu_id
        self.device = server_args.device

        # Speculative decoding settings.
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens

        # Remaining scalar settings.
        self.page_size = server_args.page_size
        self.enable_nan_detection = server_args.enable_nan_detection
        self.padded_static_len = -1

        # The draft model must use the same context length as the target model.
        target_model_config = target_worker.model_runner.model_config
        server_args.context_length = target_model_config.context_len

        # Keep `TpModelWorker.__init__` from capturing cuda graphs; they are
        # captured further down, after the attention backend exists.
        cuda_graph_disabled_by_user = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        shared_pools = target_worker.get_memory_pool()
        self.req_to_token_pool, self.token_to_kv_pool_allocator = shared_pools

        # Optional hot token ids: when a token map is configured, its size is
        # passed to the draft model as a JSON model override.
        token_map_path = server_args.speculative_token_map
        if token_map_path is None:
            self.hot_token_id = None
        else:
            self.hot_token_id = load_token_map(token_map_path)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )

        # Init draft worker. `EAGLEWorker.__init__` is skipped on purpose: the
        # grandparent initialiser is invoked explicitly.
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Put the user's cuda graph choice back before capturing graphs.
        self.draft_model_runner.server_args.disable_cuda_graph = (
            cuda_graph_disabled_by_user
        )

        # Only dp attention needs the tp group to be patched while the draft
        # model runs; otherwise a do-nothing context is used.
        if server_args.enable_dp_attention:
            self.draft_tp_context = draft_tp_context
        else:
            self.draft_tp_context = empty_context

        # Init attention backend and cuda graphs
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors (zero-dimensional, int64, uninitialised).
        scalar_shape = ()
        self.num_new_pages_per_topk = torch.empty(
            scalar_shape, dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty(
            scalar_shape, dtype=torch.int64, device=self.device
        )
|
||||||
@@ -72,6 +72,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
|
|||||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
|
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
|
||||||
|
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
)
|
||||||
|
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
|
||||||
# Other use cases
|
# Other use cases
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ suites = {
|
|||||||
TestFile("test_harmony_parser.py", 20),
|
TestFile("test_harmony_parser.py", 20),
|
||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
TestFile("test_hybrid_attn_backend.py", 100),
|
TestFile("test_hybrid_attn_backend.py", 100),
|
||||||
|
TestFile("test_standalone_speculative_decoding.py", 250),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_io_struct.py", 8),
|
TestFile("test_io_struct.py", 8),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
|
|||||||
115
test/srt/test_standalone_speculative_decoding.py
Normal file
115
test/srt/test_standalone_speculative_decoding.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
GSM_DATASET_PATH = None
|
||||||
|
|
||||||
|
|
||||||
|
# Default server arguments shared across all tests
|
||||||
|
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--speculative-algorithm",
    "STANDALONE",
    "--speculative-draft-model-path",
    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
    "--speculative-num-steps",
    "4",
    "--speculative-eagle-topk",
    "2",
    "--speculative-num-draft-tokens",
    "7",
    "--mem-fraction-static",
    # Kept as a string like every other flag value: command-line argument
    # lists handed to a subprocess must contain strings, not floats.
    "0.7",
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
    """Launch a server with standalone speculative decoding and check GSM8K.

    One server process is started per test class with the arguments returned
    by ``get_server_args``; subclasses override that hook to change the
    attention backend.
    """

    model = DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST
    draft_model = DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.7  # derived tests need to override this
    spec_decode_threshold = 3.6  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "fa3"]

    @classmethod
    def setUpClass(cls):
        """Start the server process shared by all tests of the class."""
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        os.environ.update(
            {
                "SGL_JIT_DEEPGEMM_PRECOMPILE": "false",
                "SGL_ENABLE_JIT_DEEPGEMM": "false",
            }
        )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process together with its children."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Check GSM8K accuracy and the average speculative accept length."""
        requests.get(self.base_url + "/flush_cache")

        # The port is whatever follows the last ":" of the base url.
        _, _, port_text = self.base_url.rpartition(":")
        eval_args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(port_text),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"{metrics=}")

        # Accuracy has to be strictly above the class-level threshold.
        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)

        # The accept length is read from the first entry of the server's
        # reported internal states.
        server_info = requests.get(self.base_url + "/get_server_info")
        first_internal_state = server_info.json()["internal_states"][0]
        avg_spec_accept_length = first_internal_state["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
    """Variant of the base test that launches the server with the triton backend."""

    @classmethod
    def get_server_args(cls):
        """Return the shared server arguments plus ``--attention-backend triton``."""
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "triton"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandaloneSpeculativeDecodingFlashinfer(
    TestStandaloneSpeculativeDecodingBase
):
    """Variant of the base test that launches the server with the flashinfer backend."""

    @classmethod
    def get_server_args(cls):
        """Return the shared server arguments plus ``--attention-backend flashinfer``."""
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "flashinfer"]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user