diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 4b518eccc..0e9a77f04 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Tuple
+
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 
 Qwen2Config = None
 
+logger = logging.getLogger(__name__)
+
+
 class Qwen2MLP(nn.Module):
     def __init__(
         self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     # If this function is called, it should always initialize KV cache scale
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 0855dd8ae..525498d5b 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -16,9 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -52,18 +54,21 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -535,16 +540,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
-            prefix=add_prefix("embed_tokens", prefix),
-        )
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -552,9 +562,14 @@ class Qwen2MoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -562,20 +577,35 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        if hidden_states.shape[0] != 0:
-            hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -589,6 +619,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(
@@ -609,11 +640,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -636,6 +685,16 @@ class Qwen2MoeForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -684,11 +743,14 @@ class Qwen2MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen2MoeForCausalLM
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 016f4b5de..181802a09 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -1,5 +1,6 @@
 # Adapted from qwen2.py
 
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
 
@@ -7,6 +8,7 @@ import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3Attention(nn.Module):
     def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 7f841bf37..d553395f2 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -17,6 +17,7 @@
 
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -28,6 +29,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -57,12 +59,13 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
@@ -70,6 +73,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3MoeConfig = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
@@ -516,6 +521,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(
@@ -536,12 +542,31 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -563,6 +588,17 @@ class Qwen3MoeForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -611,11 +647,14 @@ class Qwen3MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen3MoeForCausalLM
diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py
index 8a4ffaa62..3f95271ee 100644
--- a/test/srt/test_pp_single_node.py
+++ b/test/srt/test_pp_single_node.py
@@ -1,6 +1,7 @@
 """
 Usage:
 python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
+python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
 python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
 """
 
@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
         time.sleep(5)
 
 
+class TestQwenPPAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # use a different port to avoid conflicts
+        cls.model_name = "Qwen/Qwen3-8B"  # replace with your Qwen model if needed
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.74)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy differs from baseline by more than 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST