From c0fb25e9493927cfdf09f49fbe2638584600aae3 Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Thu, 24 Jul 2025 21:36:21 -0700
Subject: [PATCH] DP Enhancement (#8280)

---
 .../sglang/srt/distributed/parallel_state.py  |   9 +
 .../srt/layers/attention/base_attn_backend.py |   4 +-
 python/sglang/srt/layers/communicator.py      |  24 +-
 python/sglang/srt/layers/dp_attention.py      |  96 +-
 python/sglang/srt/layers/logits_processor.py  |  58 +-
 python/sglang/srt/layers/radix_attention.py   |   8 +-
 python/sglang/srt/managers/schedule_batch.py  |   5 +-
 .../srt/model_executor/cuda_graph_runner.py   |  86 +-
 .../srt/model_executor/forward_batch_info.py  | 215 +++-
 .../sglang/srt/model_executor/model_runner.py |  25 +-
 python/sglang/srt/models/deepseek_v2.py       |   3 +-
 python/sglang/srt/models/qwen2_moe.py         |   4 -
 python/sglang/srt/models/qwen3_moe.py         |   7 +-
 .../eagle_draft_cuda_graph_runner.py          |  60 +-
 .../eagle_draft_extend_cuda_graph_runner.py   |  73 +-
 python/sglang/srt/speculative/eagle_utils.py  |  68 +-
 python/sglang/srt/speculative/eagle_worker.py | 103 +-
 python/sglang/srt/two_batch_overlap.py        |   1 +
 test/srt/test_deepep_small.py                 |  12 +-
 test/srt/test_hybrid_dp_ep_tp_mtp.py          | 920 ++----------------
 20 files changed, 665 insertions(+), 1116 deletions(-)

diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 130bc53c7..45a1a4209 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -545,6 +545,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index bddd7891f..3025d0b11 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 5e0931ead..aeb8449a1 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 ].clone(),
                 residual,
             )
-            attn_tp_all_gather(
-                list(residual.tensor_split(context.attn_tp_size)),
local_residual - ) + attn_tp_all_gather_into_tensor(residual, local_residual) if context.attn_dp_size != 1: if context.attn_tp_rank == 0: hidden_states += residual @@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn: *, residual_input_mode, ): - tensor_list = list(hidden_states.tensor_split(context.attn_tp_size)) - hidden_states = tensor_list[context.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) + input_hidden_states = hidden_states + hidden_states = hidden_states.tensor_split(context.attn_tp_size)[ + context.attn_tp_rank + ] + attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states) if residual_input_mode == ScatterMode.TP_ATTN_FULL: residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank] if hidden_states.shape[0] != 0: @@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn: forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - attn_tp_all_gather( - list(hidden_states.tensor_split(context.attn_tp_size)), + attn_tp_all_gather_into_tensor( + hidden_states, local_hidden_states, ) return hidden_states, residual diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index ae4041956..55db13336 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -3,7 +3,8 @@ from __future__ import annotations import functools import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, List +from enum import IntEnum, auto +from typing import TYPE_CHECKING, List, Tuple import torch import triton @@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None _LOCAL_ATTN_DP_RANK = None +class DPPaddingMode(IntEnum): + + # Padding tokens to max length and then gather tokens using `all_gather_into_tensor` + MAX_LEN = auto() + # Padding tokens to sum length and then gather tokens using `all_reduce` + SUM_LEN = auto() + + def is_max_len(self): + return self == DPPaddingMode.MAX_LEN + + def is_sum_len(self): + return self == DPPaddingMode.SUM_LEN + + @classmethod + def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode: + # we choose the mode that minimizes the communication cost + max_len = max(global_num_tokens) + sum_len = sum(global_num_tokens) + if sum_len * 2 > max_len * get_attention_dp_size(): + return cls.MAX_LEN + else: + return cls.SUM_LEN + + @classmethod + def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode: + return cls.MAX_LEN + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 @@ -162,7 +191,7 @@ def disable_dp_size(): _ATTN_DP_SIZE = old_dp_size -def get_dp_local_info(forward_batch: ForwardBatch): +def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]: # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here. dp_rank = get_attention_dp_rank() @@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src): memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE) -def _dp_gather( +def _dp_gather_via_all_reduce( global_tokens: torch.Tensor, local_tokens: torch.Tensor, forward_batch: ForwardBatch, @@ -238,13 +267,6 @@ def _dp_gather( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" - # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). 
- # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the - # actual size of the accepted tokens. - if forward_batch.forward_mode.is_draft_extend(): - shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) - local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) - memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False ) @@ -263,6 +285,38 @@ def _dp_gather( global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens) +def _dp_gather_via_all_gather( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, + is_partial: bool, +): + if not is_partial: + if get_attention_tp_rank() != 0: + local_tokens.fill_(0) + scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[ + get_attention_tp_rank() + ] + get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens) + get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens) + + +def _dp_gather( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, + is_partial: bool, +): + if forward_batch.dp_padding_mode.is_max_len(): + _dp_gather_via_all_gather( + global_tokens, local_tokens, forward_batch, is_partial + ) + else: + _dp_gather_via_all_reduce( + global_tokens, local_tokens, forward_batch, is_partial + ) + + def dp_gather_partial( global_tokens: torch.Tensor, local_tokens: torch.Tensor, @@ -296,24 +350,18 @@ def dp_scatter( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" - # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). - # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the - # actual size of the accepted tokens. 
- if forward_batch.forward_mode.is_draft_extend(): - shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) - local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) - memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) -def attn_tp_reduce_scatter( - output: torch.Tensor, - input_list: List[torch.Tensor], -): - return get_attention_tp_group().reduce_scatter(output, input_list) +def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().reduce_scatter_tensor(output, input) -def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): - return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list) +def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().all_gather_into_tensor(output, input) + + +def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor): + return get_attention_tp_group().all_gather(input, output_tensor_list=output_list) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 79d38193e..0aee86f68 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -27,7 +27,9 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( + DPPaddingMode, attn_tp_all_gather, + attn_tp_all_gather_into_tensor, dp_gather_replicate, dp_scatter, get_attention_dp_rank, @@ -111,7 +113,8 @@ class LogitsMetadata: # Number of tokens to sample per DP rank global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None - + # The gather mode for DP attention + dp_padding_mode: Optional[DPPaddingMode] = None # for padding padded_static_len: int = -1 @@ -163,12 +166,12 @@ class LogitsMetadata: forward_batch_gathered_buffer=forward_batch.gathered_buffer, global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu, global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.SUM_LEN, ) - def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): - if self.global_num_tokens_for_logprob_cpu is None: - # we are capturing cuda graph - return + def compute_dp_attention_metadata(self): + # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend, + # we may use a smaller buffer in draft extend. cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0) dp_rank = get_attention_dp_rank() @@ -179,18 +182,9 @@ class LogitsMetadata: else: dp_local_start_pos = cumtokens[dp_rank - 1] dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank] - gathered_buffer = torch.zeros( - ( - sum(self.global_num_tokens_for_logprob_cpu), - hidden_states.shape[1], - ), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) self.dp_local_start_pos = dp_local_start_pos self.dp_local_num_tokens = dp_local_num_tokens - self.gathered_buffer = gathered_buffer class LogitsProcessor(nn.Module): @@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module): guarantee the given hidden_states follow this constraint. 
""" if self.do_tensor_parallel_all_gather_dp_attn: - logits_metadata.compute_dp_attention_metadata(hidden_states) + logits_metadata.compute_dp_attention_metadata() hidden_states, local_hidden_states = ( torch.empty_like(logits_metadata.gathered_buffer), hidden_states, @@ -463,15 +457,31 @@ class LogitsProcessor(nn.Module): if self.do_tensor_parallel_all_gather: if self.use_attn_tp_group: - global_logits = torch.empty( - (self.config.vocab_size, logits.shape[0]), - device=logits.device, - dtype=logits.dtype, - ) - global_logits = global_logits.T - attn_tp_all_gather( - list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits - ) + if self.config.vocab_size % self.attn_tp_size == 0: + global_logits = torch.empty( + ( + self.attn_tp_size, + logits.shape[0], + self.config.vocab_size // self.attn_tp_size, + ), + device=logits.device, + dtype=logits.dtype, + ) + attn_tp_all_gather_into_tensor(global_logits, logits) + global_logits = global_logits.permute(1, 0, 2).reshape( + logits.shape[0], self.config.vocab_size + ) + else: + global_logits = torch.empty( + (self.config.vocab_size, logits.shape[0]), + device=logits.device, + dtype=logits.dtype, + ) + global_logits = global_logits.T + attn_tp_all_gather( + list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), + logits, + ) logits = global_logits else: logits = tensor_model_parallel_all_gather(logits) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 322704ca9..8004fc7c9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -12,14 +12,16 @@ # limitations under the License. # ============================================================================== """Radix attention.""" +from __future__ import annotations from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Optional from torch import nn -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.model_executor.forward_batch_info import ForwardBatch class AttentionType(Enum): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 714af6fba..ea7cad98b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -45,7 +45,6 @@ import triton import triton.language as tl from sglang.global_config import global_config -from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.disaggregation.base import BaseKVSender from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( @@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -1880,7 +1880,7 @@ class ModelWorkerBatch: sampling_info: SamplingBatchInfo # The input Embeds - input_embeds: Optional[torch.tensor] = None + input_embeds: Optional[torch.Tensor] = None # For corss-encoder model token_type_ids: Optional[torch.Tensor] = None @@ -1890,7 +1890,6 @@ class ModelWorkerBatch: spec_info: 
Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None - spec_num_draft_tokens: Optional[int] = None hicache_consumer_index: int = 0 # Overlap event diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 520a631c5..eef7fba14 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): # is very small. We add more values here to make sure we capture the maximum bs. capture_bs += [model_runner.req_to_token_pool.size] + mul_base = 1 + if server_args.enable_two_batch_overlap: - capture_bs = [bs for bs in capture_bs if bs % 2 == 0] + mul_base *= 2 + + if require_gathered_buffer(server_args): + mul_base *= get_attention_tp_size() + + capture_bs = [bs for bs in capture_bs if bs % mul_base == 0] if server_args.cuda_graph_max_bs: capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] @@ -306,20 +313,37 @@ class CudaGraphRunner: self.encoder_lens = None if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (1,), dtype=torch.int32 + ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None self.custom_mask = torch.ones( ( @@ -342,9 +366,9 @@ class CudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.batch_size @@ -480,16 +504,19 @@ class CudaGraphRunner: if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < 
(num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens] * self.dp_size, + dtype=torch.int32, + device=input_ids.device, + ) + ) + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( @@ -498,10 +525,15 @@ class CudaGraphRunner: device=input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens], + dtype=torch.int32, + device=input_ids.device, + ) + ) gathered_buffer = self.gathered_buffer[:num_tokens] else: - global_num_tokens = None gathered_buffer = None spec_info = self.get_spec_info(num_tokens) @@ -531,7 +563,9 @@ class CudaGraphRunner: encoder_lens=encoder_lens, return_logprob=False, positions=positions, - global_num_tokens_gpu=global_num_tokens, + global_num_tokens_gpu=self.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, @@ -635,12 +669,13 @@ class CudaGraphRunner: # Pad if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -670,7 +705,8 @@ class CudaGraphRunner: if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) if enable_num_token_non_padded(self.model_runner.server_args): self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) if self.enable_two_batch_overlap: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6f3ea5474..d6850aabd 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,6 +38,11 @@ import torch import triton import triton.language as tl +from sglang.srt.layers.dp_attention import ( + DPPaddingMode, + get_attention_dp_rank, + get_attention_tp_size, +) from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.utils import ( flatten_nested_list, @@ -48,6 +53,7 @@ from sglang.srt.utils import ( if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend + from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs from sglang.srt.mem_cache.memory_pool import 
KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -242,7 +248,7 @@ class ForwardBatch: lora_paths: Optional[List[str]] = None # For input embeddings - input_embeds: Optional[torch.tensor] = None + input_embeds: Optional[torch.Tensor] = None # For cross-encoder model token_type_ids: Optional[torch.Tensor] = None @@ -261,6 +267,8 @@ class ForwardBatch: # Has to be None when cuda graph is captured. global_num_tokens_for_logprob_cpu: Optional[List[int]] = None global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None + # The padding mode for DP attention + dp_padding_mode: Optional[DPPaddingMode] = None # for extend, local start pos and num tokens is different in logits processor # this will be computed in get_dp_local_info # this will be recomputed in LogitsMetadata.from_forward_batch @@ -286,7 +294,7 @@ class ForwardBatch: # For two-batch overlap tbo_split_seq_index: Optional[int] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None - tbo_children: Optional[List["ForwardBatch"]] = None + tbo_children: Optional[List[ForwardBatch]] = None @classmethod def init_new( @@ -340,20 +348,38 @@ class ForwardBatch: len(batch.input_ids), dtype=torch.int32 ).to(device, non_blocking=True) - # For DP attention + # For MLP sync if batch.global_num_tokens is not None: - - spec_num_draft_tokens = ( - batch.spec_num_draft_tokens - if batch.spec_num_draft_tokens is not None - else 1 + from sglang.srt.speculative.eagle_utils import ( + EagleDraftInput, + EagleVerifyInput, ) - global_num_tokens = [ - x * spec_num_draft_tokens for x in batch.global_num_tokens - ] - global_num_tokens_for_logprob = [ - x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob - ] + + assert batch.global_num_tokens_for_logprob is not None + # process global_num_tokens and global_num_tokens_for_logprob + if batch.spec_info is not None: + if isinstance(batch.spec_info, EagleDraftInput): + global_num_tokens = [ + x * batch.spec_info.num_tokens_per_batch + for x in batch.global_num_tokens + ] + global_num_tokens_for_logprob = [ + x * batch.spec_info.num_tokens_for_logprob_per_batch + for x in batch.global_num_tokens_for_logprob + ] + else: + assert isinstance(batch.spec_info, EagleVerifyInput) + global_num_tokens = [ + x * batch.spec_info.draft_token_num + for x in batch.global_num_tokens + ] + global_num_tokens_for_logprob = [ + x * batch.spec_info.draft_token_num + for x in batch.global_num_tokens_for_logprob + ] + else: + global_num_tokens = batch.global_num_tokens + global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob ret.global_num_tokens_cpu = global_num_tokens ret.global_num_tokens_gpu = torch.tensor( @@ -365,15 +391,8 @@ class ForwardBatch: global_num_tokens_for_logprob, dtype=torch.int64 ).to(device, non_blocking=True) - sum_len = sum(global_num_tokens) - ret.gathered_buffer = torch.zeros( - (sum_len, model_runner.model_config.hidden_size), - dtype=model_runner.dtype, - device=device, - ) - if ret.forward_mode.is_idle(): - ret.positions = torch.empty((0,), device=device) + ret.positions = torch.empty((0,), dtype=torch.int64, device=device) TboForwardBatchPreparer.prepare( ret, is_draft_worker=model_runner.is_draft_worker ) @@ -573,6 +592,158 @@ class ForwardBatch: ) self.prefix_chunk_kv_indices.append(chunk_kv_indices) + def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0): + if value == 0: + return torch.cat( + [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])], + dim=0, + ) + else: + return 
torch.cat( + [ + tensor, + tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value), + ], + dim=0, + ) + + def prepare_mlp_sync_batch(self, model_runner: ModelRunner): + + from sglang.srt.speculative.eagle_utils import EagleDraftInput + + assert self.global_num_tokens_cpu is not None + assert self.global_num_tokens_for_logprob_cpu is not None + + global_num_tokens = self.global_num_tokens_cpu + sync_group_size = len(global_num_tokens) + attn_tp_size = get_attention_tp_size() + + for i in range(sync_group_size): + # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. + # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob + global_num_tokens[i] = ( + (global_num_tokens[i] - 1) // attn_tp_size + 1 + ) * attn_tp_size + + dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens) + self.dp_padding_mode = dp_padding_mode + + if dp_padding_mode.is_max_len(): + # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states, + # where transferred tokens should be padded to the same length. + max_num_tokens = max(global_num_tokens) + global_num_tokens = [max_num_tokens] * sync_group_size + buffer_len = max_num_tokens * sync_group_size + else: + buffer_len = sum(global_num_tokens) + + self.gathered_buffer = torch.zeros( + (buffer_len, model_runner.model_config.hidden_size), + dtype=model_runner.dtype, + device=model_runner.device, + ) + + bs = self.batch_size + if len(global_num_tokens) > 1: + num_tokens = global_num_tokens[get_attention_dp_rank()] + else: + num_tokens = global_num_tokens[0] + + # padding + self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens) + self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs) + + seq_len_fill_value = ( + model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + self.seq_lens = self._pad_tensor_to_size( + self.seq_lens, bs, value=seq_len_fill_value + ) + if self.seq_lens_cpu is not None: + self.seq_lens_cpu = self._pad_tensor_to_size( + self.seq_lens_cpu, bs, value=seq_len_fill_value + ) + + self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) + if self.encoder_lens is not None: + self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) + self.positions = self._pad_tensor_to_size(self.positions, num_tokens) + self.global_num_tokens_cpu = global_num_tokens + self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor( + global_num_tokens + ) + + if self.mrope_positions is not None: + self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs) + + if self.extend_seq_lens is not None: + self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs) + + if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput): + spec_info = self.spec_info + self.output_cache_loc_backup = self.out_cache_loc + self.hidden_states_backup = spec_info.hidden_states + if spec_info.topk_p is not None: + spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs) + if spec_info.topk_index is not None: + spec_info.topk_index = self._pad_tensor_to_size( + spec_info.topk_index, bs + ) + if spec_info.accept_length is not None: + spec_info.accept_length = self._pad_tensor_to_size( + spec_info.accept_length, bs + ) + spec_info.hidden_states = self._pad_tensor_to_size( + spec_info.hidden_states, num_tokens + ) + + def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput): + 
+ bs = self.batch_size + + if self.spec_info is not None: + if self.forward_mode.is_decode(): # draft + num_tokens = self.hidden_states_backup.shape[0] + self.positions = self.positions[:num_tokens] + self.seq_lens = self.seq_lens[:bs] + self.req_pool_indices = self.req_pool_indices[:bs] + if self.seq_lens_cpu is not None: + self.seq_lens_cpu = self.seq_lens_cpu[:bs] + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + elif self.forward_mode.is_target_verify(): # verify + num_tokens = bs * self.spec_info.draft_token_num + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + elif self.forward_mode.is_draft_extend(): # draft extend + self.spec_info.accept_length = self.spec_info.accept_length[:bs] + logits_output.next_token_logits = logits_output.next_token_logits[:bs] + logits_output.hidden_states = logits_output.hidden_states[:bs] + elif self.forward_mode.is_extend() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] + logits_output.hidden_states = logits_output.hidden_states[:bs] + + if hasattr(self, "hidden_states_backup"): + self.spec_info.hidden_states = self.hidden_states_backup + if hasattr(self, "output_cache_loc_backup"): + self.out_cache_loc = self.output_cache_loc_backup + + elif self.forward_mode.is_decode() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] + if logits_output.hidden_states is not None: + logits_output.hidden_states = logits_output.hidden_states[:bs] + elif self.forward_mode.is_extend(): + num_tokens = self.seq_lens_sum + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + if logits_output.hidden_states is not None: + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + # Here we suppose the length of each chunk is equal # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256 # num_prefix_chunks = cdiv(1024, 256) = 4 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cbb35bf27..3d3be71f1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1464,9 +1464,13 @@ class ModelRunner: tensor_parallel(self.model, device_mesh) def forward_decode( - self, forward_batch: ForwardBatch, pp_proxy_tensors=None + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors=None, ) -> LogitsProcessorOutput: - self.attn_backend.init_forward_metadata(forward_batch) + if not skip_attn_backend_init: + self.attn_backend.init_forward_metadata(forward_batch) # FIXME: add pp_proxy_tensors arg to all models kwargs = {} if self.support_pp: @@ -1578,8 +1582,18 @@ class ModelRunner: skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) - elif forward_batch.forward_mode.is_decode(): - ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors) + return ret, can_run_cuda_graph + + # For MLP sync + if forward_batch.global_num_tokens_cpu is not None: + forward_batch.prepare_mlp_sync_batch(self) + + if forward_batch.forward_mode.is_decode(): + ret = self.forward_decode( + forward_batch, + skip_attn_backend_init=skip_attn_backend_init, + 
pp_proxy_tensors=pp_proxy_tensors, + ) elif forward_batch.forward_mode.is_extend(): ret = self.forward_extend( forward_batch, @@ -1597,6 +1611,9 @@ class ModelRunner: else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") + if forward_batch.global_num_tokens_cpu is not None: + forward_batch.post_forward_mlp_sync_batch(ret) + return ret, can_run_cuda_graph def _preprocess_logits( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e02d30839..7c627bc09 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -550,9 +550,8 @@ class DeepseekV2MoE(nn.Module): def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: - forward_mode = forward_batch.forward_mode shared_output = None - if is_non_idle_and_non_empty(forward_mode, hidden_states): + if hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index e033424cf..291678652 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import ( ScatterMode, ) from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather, - attn_tp_reduce_scatter, - dp_gather_partial, - dp_scatter, get_attention_tp_rank, get_attention_tp_size, get_local_attention_dp_size, diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c75a38499..8eeee74fa 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather, - attn_tp_reduce_scatter, - dp_gather_partial, - dp_scatter, get_attention_tp_rank, get_attention_tp_size, get_local_attention_dp_size, @@ -193,8 +189,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: - forward_mode = forward_batch.forward_mode - if is_non_idle_and_non_empty(forward_mode, hidden_states): + if hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) topk_weights, topk_idx, _ = self.topk( diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 6b6c1a777..2c8cdf255 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable import torch +from sglang.srt.layers.dp_attention import DPPaddingMode from sglang.srt.model_executor.cuda_graph_runner import ( CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, @@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner: ) if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 @@ -111,12 
+105,30 @@ class EAGLEDraftCudaGraphRunner: self.global_num_tokens_for_logprob_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) self.global_num_tokens_for_logprob_gpu = torch.zeros( (1,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None # Capture try: @@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.batch_size @@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner: if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) global_num_tokens = self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( @@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner: return_logprob=False, positions=positions, global_num_tokens_gpu=global_num_tokens, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, @@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner: # Pad if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner: self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + # TODO(ch-wan): support num_token_non_padded if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) - self.global_num_tokens_for_logprob_gpu.copy_( - forward_batch.global_num_tokens_for_logprob_gpu - ) - 
forward_batch.gathered_buffer = self.gathered_buffer + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) # Attention backend if bs != raw_bs: diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 7057c502d..f4ed31d7e 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable import torch +from sglang.srt.layers.dp_attention import DPPaddingMode from sglang.srt.model_executor.cuda_graph_runner import ( CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, @@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner: ) if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 @@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner: self.global_num_tokens_for_logprob_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) self.global_num_tokens_for_logprob_gpu = torch.zeros( (1,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None + # Capture try: with model_capture_mode(): @@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.seq_lens.numel() @@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner: if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [bs] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] - global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( @@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner: ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [num_tokens], + [bs], dtype=torch.int32, 
device=self.input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu gathered_buffer = self.gathered_buffer[:num_tokens] - global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu else: - global_num_tokens = None gathered_buffer = None - global_num_tokens_for_logprob = None spec_info = EagleDraftInput( hidden_states=hidden_states, @@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner: seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, - global_num_tokens_gpu=global_num_tokens, - global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob, + global_num_tokens_gpu=self.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, @@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner: raw_bs = forward_batch.batch_size num_tokens = forward_batch.input_ids.shape[0] if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) @@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner: self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + # TODO(ch-wan): support num_token_non_padded if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) - self.global_num_tokens_for_logprob_gpu.copy_( - forward_batch.global_num_tokens_for_logprob_gpu - ) - forward_batch.gathered_buffer = self.gathered_buffer + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs) if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 7f7e21e96..aa49e4fc7 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -71,9 +71,20 @@ class EagleDraftInput: kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None + # Shape info for padding + num_tokens_per_batch: int = -1 + num_tokens_for_logprob_per_batch: int = -1 + + # Inputs for draft extend + # shape: (b,) + seq_lens_for_draft_extend: torch.Tensor = None + req_pool_indices_for_draft_extend: torch.Tensor = None + def prepare_for_extend(self, batch: ScheduleBatch): + if batch.forward_mode.is_idle(): return + # Prefill only generate 1 token. 
assert len(self.verified_id) == len(batch.seq_lens) @@ -95,7 +106,7 @@ class EagleDraftInput: capture_hidden_mode: CaptureHiddenMode, ): return cls( - verified_id=None, + verified_id=torch.empty((0,), device=device, dtype=torch.int32), hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), @@ -109,7 +120,10 @@ class EagleDraftInput: batch: ScheduleBatch, speculative_num_steps: int, ): - batch.forward_mode = ForwardMode.DRAFT_EXTEND + + if batch.forward_mode.is_idle(): + return + batch.input_ids = self.verified_id batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) @@ -316,7 +330,7 @@ class EagleVerifyInput: def verify( self, batch: ScheduleBatch, - logits_output: torch.Tensor, + logits_output: LogitsProcessorOutput, token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, page_size: int, vocab_mask: Optional[torch.Tensor] = None, # For grammar @@ -599,13 +613,14 @@ class EagleVerifyInput: batch.out_cache_loc = tgt_cache_loc batch.seq_lens.add_(accept_length + 1) - draft_input = EagleDraftInput() - draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] - draft_input.verified_id = verified_id - draft_input.accept_length = accept_length - draft_input.accept_length_cpu = accept_length.tolist() - draft_input.seq_lens_for_draft_extend = batch.seq_lens - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[accept_index], + verified_id=verified_id, + accept_length=accept_length, + accept_length_cpu=accept_length.tolist(), + seq_lens_for_draft_extend=batch.seq_lens, + req_pool_indices_for_draft_extend=batch.req_pool_indices, + ) return EagleVerifyOutput( draft_input=draft_input, @@ -628,7 +643,6 @@ class EagleVerifyInput: batch.seq_lens.add_(accept_length + 1) accept_length_cpu = accept_length.tolist() - draft_input = EagleDraftInput() if len(unfinished_accept_index) > 0: unfinished_accept_index = torch.cat(unfinished_accept_index) unfinished_index_device = torch.tensor( @@ -659,18 +673,26 @@ class EagleVerifyInput: next_power_of_2(self.draft_token_num), ) - draft_input.hidden_states = batch.spec_info.hidden_states[ - unfinished_accept_index - ] - draft_input.verified_id = predict[unfinished_accept_index] - draft_input.accept_length_cpu = draft_input_accept_length_cpu - draft_input.accept_length = accept_length[unfinished_index_device] - draft_input.seq_lens_for_draft_extend = batch.seq_lens[ - unfinished_index_device - ] - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index_device - ] + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[ + unfinished_accept_index + ], + verified_id=predict[unfinished_accept_index], + accept_length_cpu=draft_input_accept_length_cpu, + accept_length=accept_length[unfinished_index_device], + seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device], + req_pool_indices_for_draft_extend=batch.req_pool_indices[ + unfinished_index_device + ], + ) + else: + draft_input = EagleDraftInput.create_idle_input( + device=batch.device, + hidden_size=batch.model_config.hidden_size, + dtype=batch.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) return EagleVerifyOutput( draft_input=draft_input, diff --git 
a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index b6a6dace6..2d2e23a01 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker): def forward_batch_speculative_generation( self, batch: ScheduleBatch - ) -> Tuple[LogitsProcessorOutput, List[int], int, int]: + ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]: """Run speculative decoding forward. NOTE: Many states of batch is modified as you go through. It is not guaranteed that @@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker): self.verify(batch, spec_info) ) - if self.check_forward_draft_extend_after_decode(batch): - with self.draft_tp_context(self.draft_model_runner.tp_group): - self.forward_draft_extend_after_decode( - batch, - ) + with self.draft_tp_context(self.draft_model_runner.tp_group): + # NOTE: We should use `check_forward_draft_extend_after_decode` + # when DP attention is enabled, but it is slow. Skip it for now. + if ( + self.server_args.enable_dp_attention + or batch.spec_info.verified_id.shape[0] > 0 + ): + # decode is not finished + self.forward_draft_extend_after_decode(batch) + return ( logits_output, verify_output.verified_id, @@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker): ) def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): - local_need_forward = ( - batch.spec_info.verified_id is not None - and batch.spec_info.verified_id.shape[0] > 0 - ) + local_need_forward = batch.spec_info.verified_id.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward @@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker): def forward_target_extend( self, batch: ScheduleBatch - ) -> Tuple[LogitsProcessorOutput, List[int], int]: + ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]: """Run the target extend. Args: @@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker): # We need the full hidden states to prefill the KV cache of the draft model. 
model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - model_worker_batch.spec_num_draft_tokens = 1 logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( model_worker_batch ) @@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker): self._draft_preprocess_decode(batch) spec_info = batch.spec_info + assert isinstance(spec_info, EagleDraftInput) spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + spec_info.num_tokens_per_batch = self.topk + spec_info.num_tokens_for_logprob_per_batch = self.topk batch.return_hidden_states = False # Get forward batch model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_num_draft_tokens = self.topk assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner @@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker): forward_batch ) else: + forward_batch.can_run_dp_cuda_graph = False if not forward_batch.forward_mode.is_idle(): # Initialize attention backend self.draft_attn_backend.init_forward_metadata(forward_batch) @@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker): def draft_forward(self, forward_batch: ForwardBatch): # Parse args spec_info = forward_batch.spec_info + assert isinstance(spec_info, EagleDraftInput) out_cache_loc = forward_batch.out_cache_loc topk_p, topk_index, hidden_states = ( spec_info.topk_p, @@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker): spec_info.hidden_states = hidden_states # Run forward - logits_output = self.draft_model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + logits_output, _ = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True ) self._detect_nan_if_needed(logits_output) probs = torch.softmax(logits_output.next_token_logits, dim=-1) @@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker): else ForwardMode.IDLE ) batch.spec_info = spec_info + model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu ) - model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode if batch.has_grammar: @@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker): self, batch: ScheduleBatch, hidden_states: torch.Tensor, - next_token_ids: List[int], - seq_lens_cpu: torch.Tensor, + next_token_ids: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], ): """Run draft model extend. This API modifies the states of the batch. 
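A worked sketch of the MLP-sync padding that the draft-extend paths above and below rely on. The helper names and the per-rank token counts here are hypothetical; the logic mirrors DPPaddingMode.get_dp_padding_mode and the attn-TP rounding loop in prepare_mlp_sync_batch, assuming the attention-DP world size equals len(global_num_tokens). This is an illustration, not the sglang API.

from typing import List

ATTN_TP_SIZE = 2  # hypothetical attention-TP degree

def round_up(n: int, multiple: int) -> int:
    # Pad each rank's token count so a reduce-scatter across the
    # attn-TP dimension splits the tensor evenly.
    return ((n - 1) // multiple + 1) * multiple

def choose_padding_mode(global_num_tokens: List[int]) -> str:
    # MAX_LEN pads every rank to the max and gathers with
    # all_gather_into_tensor; SUM_LEN keeps the true lengths and
    # combines the summed buffer with an all_reduce.
    max_len, sum_len = max(global_num_tokens), sum(global_num_tokens)
    if sum_len * 2 > max_len * len(global_num_tokens):
        return "MAX_LEN"
    return "SUM_LEN"

tokens = [round_up(n, ATTN_TP_SIZE) for n in [5, 3, 8, 2]]  # -> [6, 4, 8, 2]
print(choose_padding_mode(tokens))  # sum=20, 2*20=40 > 8*4=32 -> "MAX_LEN"

Padding [6, 4, 8, 2] up to the max wastes 32 - 20 = 12 token slots, but it replaces an all_reduce over the 20-row summed buffer with a reduce-scatter plus all_gather over equal 8-row shards; the 2 * sum_len > max_len * dp_size threshold roughly compares the traffic of those two schemes.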
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker): batch.spec_info = EagleDraftInput( hidden_states=hidden_states, verified_id=next_token_ids, + num_tokens_per_batch=1, + num_tokens_for_logprob_per_batch=1, ) batch.return_hidden_states = False batch.spec_info.prepare_for_extend(batch) @@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker): model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) - model_worker_batch.spec_num_draft_tokens = 1 forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker): self.capture_for_decode(logits_output, forward_batch.spec_info) def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + assert isinstance(batch.spec_info, EagleDraftInput) # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() req_pool_indices_backup = batch.req_pool_indices accept_length_backup = batch.spec_info.accept_length return_logprob_backup = batch.return_logprob + input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle: - # Prepare metadata - if batch.spec_info.verified_id is not None: - batch.spec_info.prepare_extend_after_decode( - batch, - self.speculative_num_steps, - ) - else: - batch = batch.copy() - batch.prepare_for_idle() - hidden_size = ( - self.model_config.hidden_size * 3 - if self.speculative_algorithm.is_eagle3() - else self.model_config.hidden_size - ) - batch.spec_info = EagleDraftInput.create_idle_input( - device=self.device, - hidden_size=hidden_size, - dtype=self.model_config.dtype, - topk=self.topk, - capture_hidden_mode=CaptureHiddenMode.LAST, - ) + + if not input_is_idle and batch.spec_info.verified_id.numel() == 0: + batch = batch.copy() + batch.prepare_for_idle() + hidden_size = ( + self.model_config.hidden_size * 3 + if self.speculative_algorithm.is_eagle3() + else self.model_config.hidden_size + ) + batch.spec_info = EagleDraftInput.create_idle_input( + device=self.device, + hidden_size=hidden_size, + dtype=self.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1 + batch.spec_info.num_tokens_for_logprob_per_batch = 1 + batch.spec_info.prepare_extend_after_decode( + batch, + self.speculative_num_steps, + ) + batch.forward_mode = ( + ForwardMode.DRAFT_EXTEND + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1 assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner @@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker): ) forward_batch.spec_info.hidden_states = logits_output.hidden_states else: + forward_batch.can_run_dp_cuda_graph = False if not forward_batch.forward_mode.is_idle(): self.draft_model_runner.attn_backend.init_forward_metadata( forward_batch ) - logits_output = self.draft_model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + logits_output, _ = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True ) self.capture_for_decode(logits_output, forward_batch.spec_info) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 74bc1ba85..e802a7254 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ 
b/python/sglang/srt/two_batch_overlap.py @@ -545,6 +545,7 @@ class TboForwardBatchPreparer: tbo_children=None, global_num_tokens_gpu=None, global_num_tokens_cpu=None, + dp_padding_mode=None, gathered_buffer=gathered_buffer, global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_cpu=None, diff --git a/test/srt/test_deepep_small.py b/test/srt/test_deepep_small.py index e26017ade..0f6ccb955 100644 --- a/test/srt/test_deepep_small.py +++ b/test/srt/test_deepep_small.py @@ -35,7 +35,7 @@ class TestPureDP(CustomTestCase): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "512", "--mem-fraction-static", "0.5", ], @@ -81,7 +81,7 @@ class TestHybridDPTP(CustomTestCase): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "256", ], ) @@ -170,7 +170,7 @@ class TestNoGatherdBuffer(CustomTestCase): "--cuda-graph-max-bs", "32", "--max-running-requests", - "128", + "512", ], ) @@ -217,7 +217,7 @@ class TestTBO(CustomTestCase): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "512", ], ) @@ -273,7 +273,7 @@ class TestMTP(CustomTestCase): "--cuda-graph-max-bs", "32", "--max-running-requests", - "32", + "64", ], ) @@ -343,7 +343,7 @@ class TestMTPWithTBO(CustomTestCase): "--cuda-graph-max-bs", "32", "--max-running-requests", - "32", + "128", ], ) diff --git a/test/srt/test_hybrid_dp_ep_tp_mtp.py b/test/srt/test_hybrid_dp_ep_tp_mtp.py index a3d44a67a..74363649a 100644 --- a/test/srt/test_hybrid_dp_ep_tp_mtp.py +++ b/test/srt/test_hybrid_dp_ep_tp_mtp.py @@ -16,7 +16,7 @@ from sglang.test.test_utils import ( ) -class Test0(CustomTestCase): +class Test00(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -47,23 +47,10 @@ class Test0(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test1(CustomTestCase): +class Test01(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -97,23 +84,10 @@ class Test1(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test2(CustomTestCase): +class Test02(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -147,23 +121,10 @@ class Test2(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test3(CustomTestCase): +class Test03(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -196,23 +157,10 @@ class Test3(CustomTestCase): metrics = 
run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test4(CustomTestCase): +class Test04(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -248,23 +196,10 @@ class Test4(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test5(CustomTestCase): +class Test05(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -300,23 +235,10 @@ class Test5(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test6(CustomTestCase): +class Test06(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -351,23 +273,10 @@ class Test6(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test7(CustomTestCase): +class Test07(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -402,23 +311,10 @@ class Test7(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test8(CustomTestCase): +class Test08(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -455,23 +351,10 @@ class Test8(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test9(CustomTestCase): +class Test09(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -508,20 +391,7 @@ class Test9(CustomTestCase): metrics = run_eval(args) 
print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test10(CustomTestCase): @@ -560,20 +430,7 @@ class Test10(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test11(CustomTestCase): @@ -615,20 +472,7 @@ class Test11(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test12(CustomTestCase): @@ -670,20 +514,7 @@ class Test12(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test13(CustomTestCase): @@ -724,20 +555,7 @@ class Test13(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test14(CustomTestCase): @@ -781,20 +599,7 @@ class Test14(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test15(CustomTestCase): @@ -838,20 +643,7 @@ class Test15(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test16(CustomTestCase): @@ -894,20 +686,7 @@ class Test16(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - 
num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test17(CustomTestCase): @@ -950,20 +729,7 @@ class Test17(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test18(CustomTestCase): @@ -1008,20 +774,7 @@ class Test18(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test19(CustomTestCase): @@ -1066,20 +819,7 @@ class Test19(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test20(CustomTestCase): @@ -1114,20 +854,7 @@ class Test20(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test21(CustomTestCase): @@ -1165,20 +892,7 @@ class Test21(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test22(CustomTestCase): @@ -1216,20 +930,7 @@ class Test22(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test23(CustomTestCase): @@ -1266,20 +967,7 @@ class Test23(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) 
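The test_hybrid_dp_ep_tp_mtp.py hunks repeat one pattern per class: the `test_mgsm_en` case is removed, the surviving eval keeps running with its score threshold lowered from 0.5 to 0.48, and the single-digit classes are renamed from `Test0`-`Test9` to `Test00`-`Test09`. The zero-padding matters because unittest's default loader discovers test classes in lexicographic name order, so padding keeps the run order numeric once two-digit classes exist. A quick illustration (plain Python, not part of the patch):

# Lexicographic vs. zero-padded ordering of the test class names.
names = [f"Test{i}" for i in range(12)]
print(sorted(names))
# ['Test0', 'Test1', 'Test10', 'Test11', 'Test2', ..., 'Test9']

padded = [f"Test{i:02d}" for i in range(12)]
print(sorted(padded))
# ['Test00', 'Test01', ..., 'Test09', 'Test10', 'Test11']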
class Test24(CustomTestCase): @@ -1319,20 +1007,7 @@ class Test24(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test25(CustomTestCase): @@ -1372,20 +1047,7 @@ class Test25(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test26(CustomTestCase): @@ -1424,20 +1086,7 @@ class Test26(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test27(CustomTestCase): @@ -1476,20 +1125,7 @@ class Test27(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test28(CustomTestCase): @@ -1530,20 +1166,7 @@ class Test28(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test29(CustomTestCase): @@ -1584,20 +1207,7 @@ class Test29(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test30(CustomTestCase): @@ -1641,20 +1251,7 @@ class Test30(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test31(CustomTestCase): @@ -1701,20 +1298,7 @@ class Test31(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def 
test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test32(CustomTestCase): @@ -1761,20 +1345,7 @@ class Test32(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test33(CustomTestCase): @@ -1820,20 +1391,7 @@ class Test33(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test34(CustomTestCase): @@ -1882,20 +1440,7 @@ class Test34(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test35(CustomTestCase): @@ -1944,20 +1489,7 @@ class Test35(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test36(CustomTestCase): @@ -2005,20 +1537,7 @@ class Test36(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test37(CustomTestCase): @@ -2066,20 +1585,7 @@ class Test37(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test38(CustomTestCase): @@ -2129,20 +1635,7 @@ class Test38(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = 
run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test39(CustomTestCase): @@ -2192,20 +1685,7 @@ class Test39(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test40(CustomTestCase): @@ -2256,20 +1736,7 @@ class Test40(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test41(CustomTestCase): @@ -2323,20 +1790,7 @@ class Test41(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test42(CustomTestCase): @@ -2390,20 +1844,7 @@ class Test42(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test43(CustomTestCase): @@ -2456,20 +1897,7 @@ class Test43(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test44(CustomTestCase): @@ -2525,20 +1953,7 @@ class Test44(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test45(CustomTestCase): @@ -2594,20 +2009,7 @@ class Test45(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test46(CustomTestCase): @@ -2662,20 +2064,7 @@ 
class Test46(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test47(CustomTestCase): @@ -2730,20 +2119,7 @@ class Test47(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test48(CustomTestCase): @@ -2800,20 +2176,7 @@ class Test48(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test49(CustomTestCase): @@ -2870,20 +2233,7 @@ class Test49(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test50(CustomTestCase): @@ -2928,20 +2278,7 @@ class Test50(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test51(CustomTestCase): @@ -2989,20 +2326,7 @@ class Test51(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test52(CustomTestCase): @@ -3050,20 +2374,7 @@ class Test52(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test53(CustomTestCase): @@ -3110,20 +2421,7 @@ class Test53(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - 
base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test54(CustomTestCase): @@ -3173,20 +2471,7 @@ class Test54(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test55(CustomTestCase): @@ -3236,20 +2521,7 @@ class Test55(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test56(CustomTestCase): @@ -3298,20 +2570,7 @@ class Test56(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test57(CustomTestCase): @@ -3360,20 +2619,7 @@ class Test57(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test58(CustomTestCase): @@ -3424,20 +2670,7 @@ class Test58(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test59(CustomTestCase): @@ -3488,20 +2721,7 @@ class Test59(CustomTestCase): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) if __name__ == "__main__":