Simplify tests & Fix trtllm custom allreduce registration (#4252)
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -266,7 +266,7 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
|
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
|
||||||
|
|
||||||
USE_VLLM_CUSTOM_ALLREDUCE=0 python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
|
# USE_VLLM_CUSTOM_ALLREDUCE=0 python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
|
||||||
|
|
||||||
- name: Benchmark single latency + torch.compile (TP=2)
|
- name: Benchmark single latency + torch.compile (TP=2)
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.library
|
import torch.library
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip, is_hpu
|
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
use_vllm_custom_allreduce = get_bool_env_var(
|
||||||
|
"USE_VLLM_CUSTOM_ALLREDUCE", default="true"
|
||||||
|
)
|
||||||
|
|
||||||
if not is_hpu():
|
if not is_hpu():
|
||||||
# ROCm does not use vllm custom allreduce
|
# ROCm does not use vllm custom allreduce
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
safetensors_weights_iterator,
|
safetensors_weights_iterator,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
is_pin_memory_available,
|
is_pin_memory_available,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
Returns the path to the downloaded model, or None if the model is not
|
Returns the path to the downloaded model, or None if the model is not
|
||||||
downloaded from ModelScope."""
|
downloaded from ModelScope."""
|
||||||
if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
|
if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
|
||||||
# download model from ModelScope hub,
|
# download model from ModelScope hub,
|
||||||
# lazy import so that modelscope is not required for normal use.
|
# lazy import so that modelscope is not required for normal use.
|
||||||
# pylint: disable=C.
|
# pylint: disable=C.
|
||||||
|
|||||||
@@ -100,7 +100,6 @@ void cublas_grouped_gemm(
|
|||||||
check_device_dtype(out_dtype, inputs);
|
check_device_dtype(out_dtype, inputs);
|
||||||
check_device_dtype(out_dtype, weights);
|
check_device_dtype(out_dtype, weights);
|
||||||
check_device_dtype(out_dtype, outputs);
|
check_device_dtype(out_dtype, outputs);
|
||||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
|
||||||
|
|
||||||
// Weights should be transposed to (n, k) of column major
|
// Weights should be transposed to (n, k) of column major
|
||||||
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
||||||
@@ -132,7 +131,6 @@ void cublas_grouped_gemm(
|
|||||||
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
||||||
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
||||||
|
|
||||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
|
||||||
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
|
|
||||||
// Should allocate tensors for storage of pointers
|
// Should allocate tensors for storage of pointers
|
||||||
@@ -141,6 +139,9 @@ void cublas_grouped_gemm(
|
|||||||
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||||
|
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||||
|
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||||
|
|
||||||
auto status = cublasGemmGroupedBatchedEx(
|
auto status = cublasGemmGroupedBatchedEx(
|
||||||
handle,
|
handle,
|
||||||
transa_array.data(),
|
transa_array.data(),
|
||||||
|
|||||||
@@ -32,11 +32,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
||||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||||
|
|
||||||
m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||||
m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
|
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||||
|
|
||||||
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
|
||||||
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/attention
|
* From csrc/attention
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import math
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,30 +20,33 @@ suites = {
|
|||||||
TestFile("models/test_generation_models.py", 103),
|
TestFile("models/test_generation_models.py", 103),
|
||||||
TestFile("models/test_qwen_models.py", 82),
|
TestFile("models/test_qwen_models.py", 82),
|
||||||
TestFile("models/test_reward_models.py", 83),
|
TestFile("models/test_reward_models.py", 83),
|
||||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
|
||||||
TestFile("models/test_gme_qwen_models.py", 45),
|
TestFile("models/test_gme_qwen_models.py", 45),
|
||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_chunked_prefill.py", 336),
|
TestFile("test_chunked_prefill.py", 336),
|
||||||
TestFile("test_custom_allreduce.py", 1),
|
|
||||||
TestFile("test_double_sparsity.py", 50),
|
|
||||||
TestFile("test_eagle_infer.py", 447),
|
TestFile("test_eagle_infer.py", 447),
|
||||||
|
TestFile("test_ebnf_constrained.py"),
|
||||||
|
TestFile("test_fp8_kernel.py", 2),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
TestFile("test_eval_accuracy_mini.py", 63),
|
|
||||||
TestFile("test_gguf.py", 78),
|
TestFile("test_gguf.py", 78),
|
||||||
|
TestFile("test_gptqmodel_dynamic.py", 72),
|
||||||
|
TestFile("test_hidden_states.py", 55),
|
||||||
|
TestFile("test_int8_kernel.py", 1),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
|
TestFile("test_json_constrained.py", 98),
|
||||||
|
TestFile("test_large_max_new_tokens.py", 41),
|
||||||
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_mla.py", 92),
|
TestFile("test_mla.py", 92),
|
||||||
TestFile("test_mla_deepseek_v3.py", 221),
|
TestFile("test_mla_deepseek_v3.py", 221),
|
||||||
TestFile("test_mla_flashinfer.py", 395),
|
TestFile("test_mla_flashinfer.py", 395),
|
||||||
TestFile("test_mla_fp8.py", 93),
|
TestFile("test_mla_fp8.py", 93),
|
||||||
TestFile("test_json_constrained.py", 98),
|
|
||||||
TestFile("test_large_max_new_tokens.py", 41),
|
|
||||||
TestFile("test_metrics.py", 32),
|
|
||||||
TestFile("test_no_chunked_prefill.py", 126),
|
TestFile("test_no_chunked_prefill.py", 126),
|
||||||
TestFile("test_no_overlap_scheduler.py", 262),
|
TestFile("test_no_overlap_scheduler.py", 262),
|
||||||
TestFile("test_openai_server.py", 124),
|
TestFile("test_openai_server.py", 124),
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||||
TestFile("test_radix_attention.py", 167),
|
TestFile("test_radix_attention.py", 167),
|
||||||
|
TestFile("test_reasoning_content.py", 89),
|
||||||
TestFile("test_regex_constrained.py", 64),
|
TestFile("test_regex_constrained.py", 64),
|
||||||
TestFile("test_release_memory_occupation.py", 44),
|
TestFile("test_release_memory_occupation.py", 44),
|
||||||
TestFile("test_request_length_validation.py", 31),
|
TestFile("test_request_length_validation.py", 31),
|
||||||
@@ -58,7 +61,6 @@ suites = {
|
|||||||
TestFile("test_torchao.py", 70),
|
TestFile("test_torchao.py", 70),
|
||||||
TestFile("test_triton_attention_kernels.py", 4),
|
TestFile("test_triton_attention_kernels.py", 4),
|
||||||
TestFile("test_triton_attention_backend.py", 134),
|
TestFile("test_triton_attention_backend.py", 134),
|
||||||
TestFile("test_hidden_states.py", 55),
|
|
||||||
TestFile("test_update_weights_from_disk.py", 114),
|
TestFile("test_update_weights_from_disk.py", 114),
|
||||||
TestFile("test_update_weights_from_tensor.py", 48),
|
TestFile("test_update_weights_from_tensor.py", 48),
|
||||||
TestFile("test_vertex_endpoint.py", 31),
|
TestFile("test_vertex_endpoint.py", 31),
|
||||||
@@ -66,10 +68,6 @@ suites = {
|
|||||||
TestFile("test_vision_llm.py", 18.4),
|
TestFile("test_vision_llm.py", 18.4),
|
||||||
TestFile("test_vision_openai_server.py", 344),
|
TestFile("test_vision_openai_server.py", 344),
|
||||||
TestFile("test_w8a8_quantization.py", 46),
|
TestFile("test_w8a8_quantization.py", 46),
|
||||||
TestFile("test_fp8_kernel.py", 2),
|
|
||||||
TestFile("test_block_int8.py", 22),
|
|
||||||
TestFile("test_int8_kernel.py", 1),
|
|
||||||
TestFile("test_reasoning_content.py", 89),
|
|
||||||
],
|
],
|
||||||
"nightly": [
|
"nightly": [
|
||||||
TestFile("test_nightly_gsm8k_eval.py"),
|
TestFile("test_nightly_gsm8k_eval.py"),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import unittest
|
|||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
||||||
|
get_bool_env_var,
|
||||||
is_in_ci,
|
is_in_ci,
|
||||||
run_bench_one_batch,
|
run_bench_one_batch,
|
||||||
write_github_step_summary,
|
write_github_step_summary,
|
||||||
@@ -27,9 +28,13 @@ class TestBenchOneBatch(unittest.TestCase):
|
|||||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2", "--cuda-graph-max-bs", "2"]
|
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2", "--cuda-graph-max-bs", "2"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_vllm_custom_allreduce = get_bool_env_var(
|
||||||
|
"USE_VLLM_CUSTOM_ALLREDUCE", default="true"
|
||||||
|
)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(
|
write_github_step_summary(
|
||||||
f"### test_moe_tp2_bs1\n"
|
f"### test_moe_tp2_bs1 ({use_vllm_custom_allreduce=})\n"
|
||||||
f"output_throughput : {output_throughput:.2f} token/s\n"
|
f"output_throughput : {output_throughput:.2f} token/s\n"
|
||||||
)
|
)
|
||||||
self.assertGreater(output_throughput, 124)
|
self.assertGreater(output_throughput, 124)
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.run_eval import run_eval
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
popen_launch_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=["--log-level-http", "warning", "--chunked-prefill-size", "256"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def test_mmlu(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=3000,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.705, f"{metrics}"
|
|
||||||
|
|
||||||
def test_human_eval(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="humaneval",
|
|
||||||
num_examples=None,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.64, f"{metrics}"
|
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mgsm_en",
|
|
||||||
num_examples=None,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.84, f"{metrics}"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.run_eval import run_eval
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
popen_launch_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=[
|
|
||||||
"--log-level-http",
|
|
||||||
"warning",
|
|
||||||
"--chunked-prefill-size",
|
|
||||||
"256",
|
|
||||||
"--enable-mixed-chunk",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def test_mmlu(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=3000,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.705, f"{metrics}"
|
|
||||||
|
|
||||||
def test_human_eval(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="humaneval",
|
|
||||||
num_examples=None,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.64, f"{metrics}"
|
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mgsm_en",
|
|
||||||
num_examples=None,
|
|
||||||
num_threads=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
assert metrics["score"] >= 0.84, f"{metrics}"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.run_eval import run_eval
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
popen_launch_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEvalAccuracyMini(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def test_mmlu(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=64,
|
|
||||||
num_threads=32,
|
|
||||||
temperature=0.1,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -129,6 +129,7 @@ class TestGPTQModelDynamic(unittest.TestCase):
|
|||||||
"text": "The capital of France is",
|
"text": "The capital of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"temperature": 0.001,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user