From 72c77763559317b2c8bddfd67e173b67aa1facb0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 13 Jan 2025 01:39:14 -0800 Subject: [PATCH] Fix linear.py and improve weight loading (#2851) Co-authored-by: SangBin Cho --- benchmark/deepseek_v3/README.md | 7 +- docs/references/supported_models.md | 2 +- python/sglang/srt/layers/linear.py | 134 +++++------------- python/sglang/srt/layers/moe/topk.py | 6 +- python/sglang/srt/layers/parameter.py | 40 +++--- .../srt/layers/quantization/fp8_utils.py | 2 +- .../srt/layers/quantization/modelopt_quant.py | 2 +- .../srt/layers/vocab_parallel_embedding.py | 17 ++- python/sglang/srt/managers/scheduler.py | 4 + python/sglang/srt/mem_cache/memory_pool.py | 19 +++ python/sglang/srt/server.py | 3 + test/srt/test_moe_eval_accuracy_large.py | 2 +- 12 files changed, 113 insertions(+), 125 deletions(-) diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md index d14a8d556..5c353bca5 100644 --- a/benchmark/deepseek_v3/README.md +++ b/benchmark/deepseek_v3/README.md @@ -39,7 +39,7 @@ python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-r For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput. -### Example with OpenAI API +### Example: Sending requests with OpenAI API ```python3 import openai @@ -58,7 +58,8 @@ response = client.chat.completions.create( ) print(response) ``` -### Example serving with 2 H20*8 + +### Example: Serving with two H20*8 nodes For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. ```bash @@ -71,7 +72,7 @@ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --di If you have two H100 nodes, the usage is similar to the aforementioned H20. -### Example serving with Docker two H200*8 nodes +### Example: Serving with two H200*8 nodes and docker There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`. A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage. diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 9dafc3d2a..1cc7b8747 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -5,7 +5,7 @@ - Mistral / Mixtral / Mistral NeMo - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL -- DeepSeek / DeepSeek 2 +- DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) - OLMoE - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava` diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index b839deeb3..ee9386c13 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1,4 +1,4 @@ -# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" import logging from abc import abstractmethod @@ -16,7 +16,7 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) -# workaround +# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now. from vllm.model_executor.layers.linear import LinearBase from sglang.srt.layers.parameter import ( @@ -25,7 +25,6 @@ from sglang.srt.layers.parameter import ( PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter, - _ColumnvLLMParameter, ) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -43,9 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "GPTQLinearMethod", "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", ] @@ -95,62 +98,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): return param[shard_id], loaded_weight -def load_column_qkv_weight( - self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank -): - if ( - isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) - and self.output_dim == self.packed_dim - ): - shard_size, shard_offset = self.adjust_shard_indexes_for_packing( - shard_offset=shard_offset, shard_size=shard_size - ) - - param_data = self.data - shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads - param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) - loaded_weight = loaded_weight.narrow( - self.output_dim, shard_id * shard_size, shard_size - ) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - -def load_column_parallel_weight( - self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False -): - if isinstance(self, _ColumnvLLMParameter): - if not use_presharded_weights: - shard_size = self.data.shape[self.output_dim] - loaded_weight = loaded_weight.narrow( - self.output_dim, tp_rank * shard_size, shard_size - ) - assert self.data.shape == loaded_weight.shape - self.data.copy_(loaded_weight) - else: - self.data.copy_(loaded_weight) - - -def load_row_parallel_weight( - self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False -): - if isinstance(self, RowvLLMParameter): - if not use_presharded_weights: - shard_size = self.data.shape[self.input_dim] - loaded_weight = loaded_weight.narrow( - self.input_dim, tp_rank * shard_size, shard_size - ) - - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - - assert self.data.shape == loaded_weight.shape - self.data.copy_(loaded_weight) - else: - self.data.copy_(loaded_weight) - - class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -426,9 +373,7 @@ class ColumnParallelLinear(LinearBase): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): @@ -437,7 +382,7 @@ class ColumnParallelLinear(LinearBase): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_column_parallel_weight(loaded_weight=loaded_weight) + param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank) def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -565,9 +510,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data, loaded_weight, 0 ) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return current_shard_offset = 0 @@ -643,9 +586,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): "the same for all partitions." ) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) def _load_fused_module_from_checkpoint( @@ -697,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -882,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear): elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_qkv_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -896,24 +839,14 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = (shard_offset + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n - if isinstance(param, _ColumnvLLMParameter): - load_column_qkv_weight( - param, - loaded_weight, - num_heads=self.num_kv_head_replicas, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size, - tp_rank=self.tp_rank, - ) - else: - param.load_qkv_weight( - loaded_weight=loaded_weight, - num_heads=self.num_kv_head_replicas, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size, - ) + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) def weight_loader( self, @@ -962,9 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear): param_data, loaded_weight, 0 ) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return shard_offsets = [ @@ -1105,9 +1036,7 @@ class QKVParallelLinear(ColumnParallelLinear): "for all partitions." ) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1234,9 +1163,7 @@ class RowParallelLinear(LinearBase): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert ( - param_data.shape == loaded_weight.shape - ), f"{param_data.shape=}, {loaded_weight.shape=}" + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): @@ -1247,7 +1174,18 @@ class RowParallelLinear(LinearBase): assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_row_parallel_weight(loaded_weight=loaded_weight) + if isinstance(param, BasevLLMParameter): + # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py, + # It supports additional parameters like tp_rank and use_presharded_weights. + param.load_row_parallel_weight( + loaded_weight, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) + else: + # `params` is defined in `vllm/model_executor/parameter.py`, + # It does not support additional parameters. + param.load_row_parallel_weight(loaded_weight) def forward(self, input_): if self.input_is_parallel: diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 819032198..527a7d499 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -24,7 +24,9 @@ def fused_topk_native( topk: int, renormalize: bool, ): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" M, _ = hidden_states.shape topk_weights = torch.empty( M, topk, dtype=torch.float32, device=hidden_states.device @@ -180,7 +182,7 @@ def select_experts( num_expert_group=num_expert_group, topk_group=topk_group, ) - elif torch_native: + elif torch_native and custom_routing_function is None: topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 435cc69bb..fe999baa2 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -1,7 +1,4 @@ -""" -Adapted from vLLM (0.6.4.post1). -https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py -""" +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py""" import logging from fractions import Fraction @@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter): def output_dim(self): return self._output_dim - def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.data.shape[self.output_dim] - loaded_weight = loaded_weight.narrow( - self.output_dim, tp_rank * shard_size, shard_size - ) + def load_column_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): + if not use_presharded_weights: + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) @@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") @@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter): ) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow( @@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter): def input_dim(self): return self._input_dim - def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs): - use_presharded_weights = kwargs.get("use_presharded_weights") - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.data.shape[self.input_dim] + def load_row_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): if not use_presharded_weights: + shard_size = self.data.shape[self.input_dim] loaded_weight = loaded_weight.narrow( self.input_dim, tp_rank * shard_size, shard_size ) @@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter): # For row parallel layers, no sharding needed # load weight into parameter as is def load_row_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) super().load_row_parallel_weight(*args, **kwargs) def load_merged_column_weight(self, *args, **kwargs): @@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter): self._load_into_shard_id(*args, **kwargs) def load_column_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) super().load_row_parallel_weight(*args, **kwargs) def _load_into_shard_id( diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 140e70dd9..d6ff12ee1 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,8 +1,8 @@ from typing import List, Optional, Tuple import torch -from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter +from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 8ce9d20d1..5d65899d6 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported, requantize_with_max_scale, ) -from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 21d973918..a346a2cbd 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_tp: bool = True, + use_presharded_weights: bool = False, ): super().__init__() self.quant_config = quant_config @@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module): self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size + self.use_presharded_weights = use_presharded_weights + if use_presharded_weights: + assert ( + num_added_embeddings == 0 + ), "Lora is not supported with presharded weights." + self.org_vocab_size_padded = pad_vocab_size( self.org_vocab_size, self.padding_size ) @@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module): start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size + // (self.tp_size if self.use_presharded_weights else 1) + ) # Copy the data. - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param[: loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0] :].data.fill_(0) @@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding): padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_presharded_weights: bool = False, ): super().__init__( num_embeddings, @@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding): padding_size, quant_config, prefix, + use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config if bias: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 31c8018e2..1c07ea6ad 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,6 +13,7 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" +import faulthandler import logging import os import signal @@ -399,6 +400,8 @@ class Scheduler: self.watchdog_last_time = time.time() time.sleep(self.watchdog_timeout / 2) + # Wait sometimes so that the parent process can print the error. + time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) @torch.no_grad() @@ -1582,6 +1585,7 @@ def run_scheduler_process( pipe_writer, ): setproctitle.setproctitle("sglang::scheduler") + faulthandler.enable() # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 6cb186577..abee7764b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -27,6 +27,7 @@ from enum import IntEnum from functools import wraps from typing import List, Tuple, Union +import numpy as np import psutil import torch @@ -35,6 +36,8 @@ from sglang.srt.utils import debug_timing, get_compiler_backend logger = logging.getLogger(__name__) +GB = 1024 * 1024 * 1024 + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" @@ -193,6 +196,11 @@ class MHATokenToKVPool(BaseTokenToKVPool): self.layer_num = layer_num self._create_buffers() + k_size, v_size = self.get_kv_size_bytes() + logger.info( + f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB." + ) + def _create_buffers(self): # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. @@ -217,6 +225,17 @@ class MHATokenToKVPool(BaseTokenToKVPool): del self.k_buffer del self.v_buffer + def get_kv_size_bytes(self): + assert hasattr(self, "k_buffer") + assert hasattr(self, "v_buffer") + k_size_bytes = 0 + for k_cache in self.k_buffer: + k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize + v_size_bytes = 0 + for v_cache in self.v_buffer: + v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize + return k_size_bytes, v_size_bytes + # Todo: different memory layout def get_flat_data(self, indices): # prepare a large chunk of contiguous data for efficient transfer diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8fd902818..fa1625b09 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -611,6 +611,9 @@ def _set_envs_and_config(server_args: ServerArgs): # The child processes will send SIGQUIT to this process when any error happens # This process then clean up the whole process tree def sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child proces. It usually means the child failed." + ) kill_process_tree(os.getpid()) signal.signal(signal.SIGQUIT, sigquit_handler) diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 6f3affbba..dc420f00d 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -71,7 +71,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.62) + self.assertGreater(metrics["score"], 0.61) if __name__ == "__main__":