Misc fix for min_p_sampling, --cuda-graph-bs (#2761)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -66,12 +66,14 @@ jobs:
|
|||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 25
|
timeout-minutes: 25
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
|
||||||
RANGE=${{ matrix.range }}
|
RANGE=${{ matrix.range }}
|
||||||
range_begin=${RANGE%-*}
|
range_begin=${RANGE%-*}
|
||||||
range_end=${RANGE#*-}
|
range_end=${RANGE#*-}
|
||||||
|
|
||||||
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}
|
python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}
|
||||||
|
|
||||||
|
|
||||||
unit-test-backend-2-gpu:
|
unit-test-backend-2-gpu:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
runs-on: 2-gpu-runner
|
runs-on: 2-gpu-runner
|
||||||
|
|||||||
@@ -228,6 +228,7 @@ class BenchmarkWorker:
|
|||||||
hidden_size,
|
hidden_size,
|
||||||
topk,
|
topk,
|
||||||
dtype_str,
|
dtype_str,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||||
|
|||||||
@@ -16,14 +16,20 @@ classifiers = [
|
|||||||
dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
|
dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
runtime_common = ["aiohttp", "decord", "fastapi",
|
runtime_common = [
|
||||||
|
"aiohttp", "decord", "fastapi",
|
||||||
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
|
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
|
||||||
"orjson", "outlines>=0.0.44,<0.1.0",
|
"orjson", "outlines>=0.0.44,<0.1.0",
|
||||||
"packaging", "pillow", "prometheus-client>=0.20.0",
|
"packaging", "pillow", "prometheus-client>=0.20.0",
|
||||||
"psutil", "pydantic", "python-multipart",
|
"psutil", "pydantic", "python-multipart",
|
||||||
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
|
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
|
||||||
"xgrammar>=0.1.6"]
|
"xgrammar>=0.1.6"
|
||||||
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post11"]
|
]
|
||||||
|
srt = [
|
||||||
|
"sglang[runtime_common]", "cuda-python",
|
||||||
|
"sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
||||||
|
"flashinfer==0.1.6"
|
||||||
|
]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||||
|
|||||||
@@ -563,7 +563,7 @@ def sample_sharegpt_requests(
|
|||||||
raise ValueError("output_len too small")
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
# Download sharegpt if necessary
|
# Download sharegpt if necessary
|
||||||
if not os.path.isfile(dataset_path):
|
if not os.path.isfile(dataset_path) and dataset_path == "":
|
||||||
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
@@ -1064,8 +1064,11 @@ async def benchmark(
|
|||||||
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
||||||
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
||||||
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
||||||
|
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||||
"median_ttft_ms": metrics.median_ttft_ms,
|
"median_ttft_ms": metrics.median_ttft_ms,
|
||||||
|
"mean_itl_ms": metrics.mean_itl_ms,
|
||||||
"median_itl_ms": metrics.median_itl_ms,
|
"median_itl_ms": metrics.median_itl_ms,
|
||||||
|
"input_throughput": metrics.input_throughput,
|
||||||
"output_throughput": metrics.output_throughput,
|
"output_throughput": metrics.output_throughput,
|
||||||
"sharegpt_output_len": args.sharegpt_output_len,
|
"sharegpt_output_len": args.sharegpt_output_len,
|
||||||
"random_input_len": args.random_input_len,
|
"random_input_len": args.random_input_len,
|
||||||
|
|||||||
@@ -117,6 +117,11 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.final_logit_softcapping = getattr(
|
self.final_logit_softcapping = getattr(
|
||||||
self.config, "final_logit_softcapping", None
|
self.config, "final_logit_softcapping", None
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
self.final_logit_softcapping is not None
|
||||||
|
and self.final_logit_softcapping < 0
|
||||||
|
):
|
||||||
|
self.final_logit_softcapping = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1011,6 +1011,17 @@ def fused_experts_impl(
|
|||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if topk_ids.shape[1] == 1:
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
|
||||||
|
intermediate_cache3[:, 0]
|
||||||
|
)
|
||||||
|
elif topk_ids.shape[1] == 2:
|
||||||
|
torch.add(
|
||||||
|
intermediate_cache3[:, 0],
|
||||||
|
intermediate_cache3[:, 1],
|
||||||
|
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
).squeeze(dim=1)
|
||||||
|
elif topk_ids.shape[1] > 2:
|
||||||
torch.sum(
|
torch.sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
dim=1,
|
dim=1,
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||||
|
|
||||||
from typing import Callable, Dict, Optional, Type
|
from typing import Dict, Type
|
||||||
|
|
||||||
import torch
|
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import threading
|
|||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
import setproctitle
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
):
|
):
|
||||||
|
setproctitle.setproctitle("sglang::data_parallel_controller")
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
|
||||||
|
|||||||
@@ -1516,8 +1516,9 @@ class Scheduler:
|
|||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
def update_weights_from_distributed(
|
||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
self,
|
||||||
):
|
recv_req: UpdateWeightsFromDistributedReqInput,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
"""Update the online model parameter."""
|
"""Update the online model parameter."""
|
||||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||||
if success:
|
if success:
|
||||||
|
|||||||
@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
|
|||||||
documentation="Histogram of time to first token in seconds.",
|
documentation="Histogram of time to first token in seconds.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
0.001,
|
|
||||||
0.005,
|
|
||||||
0.01,
|
|
||||||
0.02,
|
|
||||||
0.04,
|
|
||||||
0.06,
|
|
||||||
0.08,
|
|
||||||
0.1,
|
0.1,
|
||||||
0.25,
|
0.25,
|
||||||
0.5,
|
0.5,
|
||||||
0.75,
|
0.75,
|
||||||
1.0,
|
1,
|
||||||
2.5,
|
2,
|
||||||
5.0,
|
5,
|
||||||
7.5,
|
10,
|
||||||
10.0,
|
20,
|
||||||
15.0,
|
40,
|
||||||
20.0,
|
60,
|
||||||
25.0,
|
80,
|
||||||
30.0,
|
120,
|
||||||
|
160,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
|
|||||||
documentation="Histogram of End-to-end request latency in seconds",
|
documentation="Histogram of End-to-end request latency in seconds",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
0.3,
|
0.1,
|
||||||
|
0.25,
|
||||||
0.5,
|
0.5,
|
||||||
0.8,
|
1,
|
||||||
1.0,
|
2,
|
||||||
1.5,
|
5,
|
||||||
2.0,
|
10,
|
||||||
2.5,
|
20,
|
||||||
5.0,
|
40,
|
||||||
10.0,
|
60,
|
||||||
15.0,
|
80,
|
||||||
20.0,
|
120,
|
||||||
30.0,
|
160,
|
||||||
40.0,
|
|
||||||
50.0,
|
|
||||||
60.0,
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -124,6 +124,13 @@ class CudaGraphRunner:
|
|||||||
self.tp_size = self.model_runner.tp_size
|
self.tp_size = self.model_runner.tp_size
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
|
self.capture_bs = self.model_runner.server_args.cuda_graph_bs
|
||||||
|
if self.capture_bs is None:
|
||||||
|
if model_runner.server_args.disable_cuda_graph_padding:
|
||||||
|
self.capture_bs = list(range(1, 33)) + [64, 128]
|
||||||
|
else:
|
||||||
|
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
|
|
||||||
if model_runner.server_args.disable_cuda_graph_padding:
|
if model_runner.server_args.disable_cuda_graph_padding:
|
||||||
self.capture_bs = list(range(1, 33)) + [64, 128]
|
self.capture_bs = list(range(1, 33)) + [64, 128]
|
||||||
else:
|
else:
|
||||||
@@ -340,8 +347,8 @@ class CudaGraphRunner:
|
|||||||
top_logprobs_nums=[0] * bs,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
global_num_tokens=global_num_tokens,
|
global_num_tokens=global_num_tokens,
|
||||||
mrope_positions=mrope_positions,
|
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ class ModelRunner:
|
|||||||
self.is_draft_worker = is_draft_worker
|
self.is_draft_worker = is_draft_worker
|
||||||
self.is_generation = model_config.is_generation
|
self.is_generation = model_config.is_generation
|
||||||
self.is_multimodal = model_config.is_multimodal
|
self.is_multimodal = model_config.is_multimodal
|
||||||
|
self.should_log = tp_rank == 0
|
||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
@@ -117,15 +118,21 @@ class ModelRunner:
|
|||||||
|
|
||||||
if self.is_multimodal:
|
if self.is_multimodal:
|
||||||
self.mem_fraction_static *= 0.95
|
self.mem_fraction_static *= 0.95
|
||||||
|
logger.info(
|
||||||
|
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
|
||||||
|
f"because this is a multimodal model."
|
||||||
|
)
|
||||||
|
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
"MllamaForConditionalGeneration"
|
"MllamaForConditionalGeneration"
|
||||||
]:
|
]:
|
||||||
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
|
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
|
||||||
server_args.chunked_prefill_size = -1
|
server_args.chunked_prefill_size = -1
|
||||||
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
"Qwen2VLForConditionalGeneration"
|
"Qwen2VLForConditionalGeneration"
|
||||||
]:
|
]:
|
||||||
|
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
logger.info(
|
logger.info(
|
||||||
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
|
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -232,6 +232,7 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
||||||
|
|
||||||
def apply_logits_bias(self, logits: torch.Tensor):
|
def apply_logits_bias(self, logits: torch.Tensor):
|
||||||
# Apply logit_bias
|
# Apply logit_bias
|
||||||
|
|||||||
@@ -127,14 +127,12 @@ async def health() -> Response:
|
|||||||
async def health_generate(request: Request) -> Response:
|
async def health_generate(request: Request) -> Response:
|
||||||
"""Check the health of the inference server by generating one token."""
|
"""Check the health of the inference server by generating one token."""
|
||||||
|
|
||||||
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
|
||||||
|
|
||||||
if tokenizer_manager.is_generation:
|
if tokenizer_manager.is_generation:
|
||||||
gri = GenerateReqInput(
|
gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
|
||||||
input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
gri = EmbeddingReqInput(
|
gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
|
||||||
input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for _ in tokenizer_manager.generate_request(gri, request):
|
async for _ in tokenizer_manager.generate_request(gri, request):
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ class ServerArgs:
|
|||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
|
cuda_graph_bs: Optional[List[int]] = None
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_nan_detection: bool = False
|
enable_nan_detection: bool = False
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
@@ -803,6 +804,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.cuda_graph_max_bs,
|
default=ServerArgs.cuda_graph_max_bs,
|
||||||
help="Set the maximum batch size for cuda graph.",
|
help="Set the maximum batch size for cuda graph.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cuda-graph-bs",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
help="Set the list of batch sizes for cuda graph.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torchao-config",
|
"--torchao-config",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -709,13 +709,14 @@ def broadcast_pyobj(
|
|||||||
data: List[Any],
|
data: List[Any],
|
||||||
rank: int,
|
rank: int,
|
||||||
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
|
src: int = 0,
|
||||||
):
|
):
|
||||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
else:
|
else:
|
||||||
serialized_data = pickle.dumps(data)
|
serialized_data = pickle.dumps(data)
|
||||||
size = len(serialized_data)
|
size = len(serialized_data)
|
||||||
@@ -724,19 +725,19 @@ def broadcast_pyobj(
|
|||||||
)
|
)
|
||||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||||
|
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
dist.broadcast(tensor_data, src=src, group=dist_group)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
size = tensor_size.item()
|
size = tensor_size.item()
|
||||||
|
|
||||||
if size == 0:
|
if size == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tensor_data = torch.empty(size, dtype=torch.uint8)
|
tensor_data = torch.empty(size, dtype=torch.uint8)
|
||||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
dist.broadcast(tensor_data, src=src, group=dist_group)
|
||||||
|
|
||||||
serialized_data = bytes(tensor_data.cpu().numpy())
|
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||||
data = pickle.loads(serialized_data)
|
data = pickle.loads(serialized_data)
|
||||||
|
|||||||
@@ -532,6 +532,8 @@ def run_bench_serving(
|
|||||||
request_rate,
|
request_rate,
|
||||||
other_server_args,
|
other_server_args,
|
||||||
dataset_name="random",
|
dataset_name="random",
|
||||||
|
dataset_path="",
|
||||||
|
tokenizer=None,
|
||||||
random_input_len=4096,
|
random_input_len=4096,
|
||||||
random_output_len=2048,
|
random_output_len=2048,
|
||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
@@ -553,9 +555,9 @@ def run_bench_serving(
|
|||||||
host=None,
|
host=None,
|
||||||
port=None,
|
port=None,
|
||||||
dataset_name=dataset_name,
|
dataset_name=dataset_name,
|
||||||
dataset_path="",
|
dataset_path=dataset_path,
|
||||||
model=None,
|
model=None,
|
||||||
tokenizer=None,
|
tokenizer=tokenizer,
|
||||||
num_prompts=num_prompts,
|
num_prompts=num_prompts,
|
||||||
sharegpt_output_len=None,
|
sharegpt_output_len=None,
|
||||||
random_input_len=random_input_len,
|
random_input_len=random_input_len,
|
||||||
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
|
|||||||
STDOUT_FILENAME = "stdout.txt"
|
STDOUT_FILENAME = "stdout.txt"
|
||||||
|
|
||||||
|
|
||||||
def read_output(output_lines):
|
def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
|
||||||
"""Print the output in real time with another thread."""
|
"""Print the output in real time with another thread."""
|
||||||
while not os.path.exists(STDERR_FILENAME):
|
while not os.path.exists(filename):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
while pt >= 0:
|
while pt >= 0:
|
||||||
if pt > 0 and not os.path.exists(STDERR_FILENAME):
|
if pt > 0 and not os.path.exists(filename):
|
||||||
break
|
break
|
||||||
lines = open(STDERR_FILENAME).readlines()
|
lines = open(filename).readlines()
|
||||||
for line in lines[pt:]:
|
for line in lines[pt:]:
|
||||||
print(line, end="", flush=True)
|
print(line, end="", flush=True)
|
||||||
output_lines.append(line)
|
output_lines.append(line)
|
||||||
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
|
|||||||
assert has_abort
|
assert has_abort
|
||||||
|
|
||||||
|
|
||||||
|
def run_command_and_capture_output(command, env: Optional[dict] = None):
|
||||||
|
stdout = open(STDOUT_FILENAME, "w")
|
||||||
|
stderr = open(STDERR_FILENAME, "w")
|
||||||
|
process = subprocess.Popen(
|
||||||
|
command, stdout=stdout, stderr=stderr, env=env, text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch a thread to stream the output
|
||||||
|
output_lines = []
|
||||||
|
t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
# Join the process
|
||||||
|
process.wait()
|
||||||
|
|
||||||
|
stdout.close()
|
||||||
|
stderr.close()
|
||||||
|
if os.path.exists(STDOUT_FILENAME):
|
||||||
|
os.remove(STDOUT_FILENAME)
|
||||||
|
if os.path.exists(STDERR_FILENAME):
|
||||||
|
os.remove(STDERR_FILENAME)
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
return output_lines
|
||||||
|
|
||||||
|
|
||||||
def run_mmlu_test(
|
def run_mmlu_test(
|
||||||
disable_radix_cache=False,
|
disable_radix_cache=False,
|
||||||
enable_mixed_chunk=False,
|
enable_mixed_chunk=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user