Misc fixes for eagle (flush_cache, CPU overhead) (#3014)
@@ -49,12 +49,13 @@ class BenchArgs:
gsp_system_prompt_len: int = 2048
gsp_question_len: int = 128
gsp_output_len: int = 256
seed: int = 1
disable_ignore_eos: bool = False
extra_request_body: Optional[str] = None
seed: int = 1
apply_chat_template: bool = False
profile: bool = False
skip_warmup: bool = False
do_not_exit: bool = False
profile: bool = False

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -141,20 +142,31 @@ class BenchArgs:
default=BenchArgs.gsp_output_len,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
type=bool,
default=BenchArgs.disable_ignore_eos,
action="store_true",
help="Disable ignore EOS token",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
default=BenchArgs.extra_request_body,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--apply-chat-template",
action="store_true",
help="Apply chat template",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--skip-warmup",
action="store_true",

@@ -165,12 +177,6 @@ class BenchArgs:
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(

@@ -517,6 +518,7 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float
std_e2e_latency_ms: float
p99_e2e_latency_ms: float
concurrency: float


SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -592,6 +595,15 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]

if apply_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
prompt = prompt.replace(tokenizer.bos_token, "")

prompt_token_ids = tokenizer.encode(prompt)
completion = dataset[i][1]
completion_token_ids = tokenizer.encode(completion)

@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)

if prompt_len < 1 or output_len < 1:
if prompt_len < 2 or output_len < 2:
# Prune too short sequences.
continue
@@ -880,6 +892,7 @@ def calculate_metrics(
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
concurrency=np.sum(e2e_latencies) / dur_s,
)

return metrics, output_lens

@@ -1031,6 +1044,7 @@ async def benchmark(
"Total token throughput (tok/s):", metrics.total_throughput
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)

@@ -1062,13 +1076,24 @@ async def benchmark(
and metrics.output_throughput is not None
):
result = {
# Arguments
"backend": args.backend,
"dataset_name": args.dataset_name,
"request_rate": request_rate,
"max_concurrency": max_concurrency,
"sharegpt_output_len": args.sharegpt_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
"random_range_ratio": args.random_range_ratio,
# Results
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
@@ -1085,14 +1110,7 @@ async def benchmark(
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"sharegpt_output_len": args.sharegpt_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
"random_range_ratio": args.random_range_ratio,
"duration": benchmark_duration,
"completed": metrics.completed,
"concurrency": metrics.concurrency,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")

@@ -1112,36 +1130,16 @@ async def benchmark(
with open(output_file_name, "a") as file:
file.write(json.dumps(result) + "\n")

result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_ttft_ms": metrics.mean_ttft_ms,
"median_ttft_ms": metrics.median_ttft_ms,
"std_ttft_ms": metrics.std_ttft_ms,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"std_tpot_ms": metrics.std_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
}
result.update(
{
"input_lens": [output.prompt_len for output in outputs],
"output_lens": output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
)
return result

@@ -1422,7 +1420,6 @@ if __name__ == "__main__":
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--multi",
action="store_true",

@@ -1445,16 +1442,17 @@ if __name__ == "__main__":
action="store_true",
help="Disable streaming mode.",
)
parser.add_argument(
"--disable-ignore-eos",
action="store_true",
help="Disable ignoring EOS.",
)
parser.add_argument(
"--return-logprob",
action="store_true",
help="Return logprob.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
action="store_true",
help="Disable ignoring EOS.",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',

@@ -1462,6 +1460,11 @@ if __name__ == "__main__":
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
parser.add_argument(
"--apply-chat-template",
action="store_true",
help="Apply chat template",
)
parser.add_argument(
"--profile",
action="store_true",

@@ -1023,7 +1023,7 @@ class Scheduler:
)

# Check for jump-forward
if not self.disable_jump_forward:
if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():

@@ -1564,6 +1564,15 @@ class Scheduler:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()

if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool.clear()

self.num_generated_tokens = 0
self.forward_ct_decode = 0
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
if_success = True

@@ -282,6 +282,9 @@ class ForwardBatch:
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
attn_backend=model_runner.attn_backend,
spec_algorithm=batch.spec_algorithm,
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,

@@ -336,11 +339,6 @@ class ForwardBatch:
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)

# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool
ret.token_to_kv_pool = model_runner.token_to_kv_pool
ret.attn_backend = model_runner.attn_backend

# Init lora information
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)

@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================

# Some shortcuts for backward compatbility.
# Some shortcuts for backward compatibility.
# They will be removed in new versions.
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server

@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
class EAGLEDraftInput(SpecInfo):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
self.sample_output = None

self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []

@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
self.cache_list: List[torch.Tenor] = []
self.iter = 0

# shape: (b, hidden_size)
self.hidden_states: torch.Tensor = None
# shape: (b,)
self.verified_id: torch.Tensor = None
# shape: (b, vocab_size)
self.sample_output: torch.Tensor = None

self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
self.has_finished: bool = False
self.unfinished_index: List[int] = None
self.accept_length_cpu: List[int] = None

def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk

@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
:pre_len
] = req.prefix_indices

batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
)

@@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+ torch.full(
[1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
)
).flatten()

bs = len(batch.seq_lens)

@@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):

def prepare_extend_after_decode(self, batch: ScheduleBatch):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
batch.extend_lens = (self.accept_length + 1).tolist()
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()

pt = 0
seq_lens = batch.seq_lens.tolist()

i = 0

for req in batch.reqs:
if req.finished():
continue
# assert seq_len - pre_len == req.extend_input_len
input_len = self.accept_length[i] + 1
seq_len = seq_lens[i]
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]

self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)

@@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
triton.next_power_of_2(self.spec_steps + 1),
)

batch.seq_lens_sum = sum(batch.seq_lens)
batch.seq_lens_sum = sum(seq_lens_cpu)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id

@@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False

# iterate every accepted token and check if req has finished after append the token
# should be checked BEFORE free kv cache slots
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):

@@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
draft_input.has_finished = True
has_finished = True
# set all tokens after finished token to -1 and break
accept_index[i, j + 1 :] = -1
break

@@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo):
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
verified_id_cpu = verified_id.tolist()

evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False

@@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo):
draft_input.verified_id = predict[new_accept_index]
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
draft_input.unfinished_index = unfinished_index
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens

logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (

@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
from sglang.srt.utils import rank0_print


class EAGLEWorker(TpModelWorker):

@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):

def forward_draft_decode(self, batch: ScheduleBatch):
batch.spec_info.prepare_for_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)

def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)

@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
batch.req_to_token_pool = runner.req_to_token_pool

def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens

self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.spec_info.has_finished:
index = batch.spec_info.unfinished_index
seq_lens = batch.seq_lens
batch.seq_lens = batch.seq_lens[index]

batch.spec_info.prepare_extend_after_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)

batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
batch.forward_mode = ForwardMode.DECODE
if batch.spec_info.has_finished:
batch.seq_lens = seq_lens
self._set_mem_pool(batch, self.target_worker.model_runner)

# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup

def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):

@@ -1442,3 +1442,10 @@ def is_valid_ipv6_address(address: str) -> bool:
return True
except ValueError:
return False


def rank0_print(msg: str):
from sglang.srt.distributed import get_tensor_model_parallel_rank

if get_tensor_model_parallel_rank() == 0:
print(msg, flush=True)

@@ -535,7 +535,8 @@ def test_hellaswag_select():

# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
assert np.abs(accuracy_gen - accuracy) < 0.1
print(f"{accuracy=}, {accuracy_gen=}")
assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(latency_gen - latency) < 1

return accuracy, latency

@@ -567,15 +567,16 @@ def run_bench_serving(
random_range_ratio=0.0,
request_rate=request_rate,
multi=None,
seed=0,
output_file=None,
disable_tqdm=False,
disable_stream=disable_stream,
disable_ignore_eos=False,
return_logprob=False,
lora_name=None,
seed=0,
disable_ignore_eos=False,
extra_request_body=None,
apply_chat_template=False,
profile=None,
lora_name=None,
)

try: