diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f4ffddb6b..0d55b63eb 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -29,6 +29,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[dev]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Run test @@ -48,6 +49,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[dev]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Run test @@ -67,6 +69,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[dev]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Run test @@ -86,6 +89,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[dev]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Run test @@ -105,6 +109,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[all]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Benchmark Single Latency @@ -136,6 +141,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[all]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Benchmark Offline Throughput (w/o RadixAttention) @@ -167,6 +173,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[all]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Benchmark Offline Throughput (TP=2) @@ -198,6 +205,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[all]" + pip install transformers==4.44 pip install flashinfer -i 
https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall git clone https://github.com/merrymercy/human-eval.git @@ -221,6 +229,7 @@ jobs: run: | pip install --upgrade pip pip install -e "python[all]" + pip install transformers==4.44 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall git clone https://github.com/merrymercy/human-eval.git diff --git a/examples/runtime/reward_model.py b/examples/runtime/reward_model.py new file mode 100644 index 000000000..3b63c8dd3 --- /dev/null +++ b/examples/runtime/reward_model.py @@ -0,0 +1,34 @@ +# launch server +# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding + +import json + +import requests + +url = "http://127.0.0.1:30000" + +PROMPT = ( + "What is the range of the numeric output of a sigmoid node in a neural network?" +) +RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1." +RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1." + +json_data = { + "conv": [ + [ + {"role": "user", "content": PROMPT}, + {"role": "assistant", "content": RESPONSE1}, + ], + [ + {"role": "user", "content": PROMPT}, + {"role": "assistant", "content": RESPONSE2}, + ], + ], +} +response = requests.post( + url + "/judge", + json=json_data, +).json() + +print(response) +print("scores:", [x["embedding"] for x in response]) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1c6994d77..0c7a57f46 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -215,12 +215,11 @@ class EmbeddingReqInput: raise ValueError("Either text or input_ids should be provided.") if self.text is not None: - is_single = isinstance(self.text, str) + self.is_single = isinstance(self.text, str) else: - is_single = isinstance(self.input_ids[0], int) - self.is_single = is_single + self.is_single = isinstance(self.input_ids[0], int) - if is_single: + if self.is_single: if self.rid is 
None: self.rid = uuid.uuid4().hex if self.sampling_params is None: @@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput: sampling_params: SamplingParams +@dataclass +class RewardReqInput: + # The input prompt in the chat format. It can be a single prompt or a batch of prompts. + conv: Union[List[List[Dict]], List[Dict]] + # The request id. + rid: Optional[Union[List[str], str]] = None + # Dummy sampling params for compatibility + sampling_params: Union[List[Dict], Dict] = None + + is_single: bool = True + + def post_init(self): + self.is_single = isinstance(self.conv[0], dict) + + if self.is_single: + if self.rid is None: + self.rid = uuid.uuid4().hex + if self.sampling_params is None: + self.sampling_params = {} + self.sampling_params["max_new_tokens"] = 1 + else: + # support select operation + self.batch_size = len(self.conv) + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)] + else: + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") + if self.sampling_params is None: + self.sampling_params = [{}] * self.batch_size + for i in range(self.batch_size): + self.sampling_params[i]["max_new_tokens"] = 1 + + +@dataclass +class TokenizedRewardReqInput: + # The request id + rid: str + # The input text + input_text: str + # The input token ids + input_ids: List[int] + # Dummy sampling params for compatibility + sampling_params: SamplingParams + + @dataclass class BatchTokenIDOut: # The request id diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e2fa246bb..b93ceb3a6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + RewardReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + TokenizedRewardReqInput, UpdateWeightReqInput, UpdateWeightReqOutput, ) @@ 
-142,7 +144,7 @@ class TokenizerManager: async def generate_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], request: Optional[fastapi.Request] = None, ): if self.to_create_loop: @@ -163,7 +165,7 @@ class TokenizerManager: async def _handle_single_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], request: Optional[fastapi.Request] = None, index: Optional[int] = None, is_cache_for_prefill: Optional[bool] = False, @@ -173,7 +175,13 @@ class TokenizerManager: rid = obj.rid if not_use_index else obj.rid[index] input_text = obj.text if not_use_index else obj.text[index] - if obj.input_ids is None: + if hasattr(obj, "conv"): + # reward model + assert self.tokenizer is not None + conv = obj.conv if not_use_index else obj.conv[index] + input_text = self.tokenizer.apply_chat_template(conv, tokenize=False) + input_ids = self.tokenizer.encode(input_text) + elif obj.input_ids is None: assert self.tokenizer is not None input_ids = self.tokenizer.encode(input_text) else: @@ -269,13 +277,21 @@ class TokenizerManager: else obj.lora_path ), ) - else: # is embedding + elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( rid, input_text, input_ids, sampling_params, ) + else: + assert isinstance(obj, RewardReqInput) + tokenized_obj = TokenizedRewardReqInput( + rid, + input_text, + input_ids, + sampling_params, + ) self.send_to_controller.send_pyobj(tokenized_obj) # Recv results @@ -292,7 +308,7 @@ class TokenizerManager: async def _handle_batch_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], request: Optional[fastapi.Request] = None, ): batch_size = obj.batch_size @@ -329,9 +345,16 @@ class TokenizerManager: rid = obj.rid[index] if parallel_sample_num == 1: ## select operation - if obj.input_ids is None: + 
if hasattr(obj, "conv"): + # reward model + conv = obj.conv[i] + input_text = self.tokenizer.apply_chat_template( + conv, tokenize=False + ) + input_ids = self.tokenizer.encode(input_text) + elif obj.input_ids is None: input_text = obj.text[i] - input_ids = self.tokenizer.encode(obj.text[i]) + input_ids = self.tokenizer.encode(input_text) else: input_text = None input_ids = obj.input_ids[i] @@ -370,13 +393,21 @@ class TokenizerManager: else obj.lora_path ), ) - else: + elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( rid, input_text, input_ids, sampling_params, ) + else: + assert isinstance(obj, RewardReqInput) + tokenized_obj = TokenizedRewardReqInput( + rid, + input_text, + input_ids, + sampling_params, + ) self.send_to_controller.send_pyobj(tokenized_obj) event = asyncio.Event() @@ -442,7 +473,7 @@ class TokenizerManager: async def _wait_for_response( self, state: ReqState, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], rid: str, request: Optional[fastapi.Request] = None, index: Optional[int] = None, @@ -469,7 +500,7 @@ class TokenizerManager: ), obj.return_text_in_logprobs, ) - else: # isinstance(obj, EmbeddingReqInput) + else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput)) out = state.out_list[-1] out["index"] = response_index diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 414424e5b..b96906700 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -22,7 +22,7 @@ import os import pickle import time import warnings -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed @@ -41,6 +41,7 @@ from sglang.srt.managers.io_struct import ( FlushCacheReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + TokenizedRewardReqInput, UpdateWeightReqInput, UpdateWeightReqOutput, ) @@ -223,7 
+224,9 @@ class ModelTpServer: if isinstance(recv_req, TokenizedGenerateReqInput): self.handle_generate_request(recv_req) self.do_not_get_new_batch = False - elif isinstance(recv_req, TokenizedEmbeddingReqInput): + elif isinstance( + recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput) + ): self.handle_embedding_request(recv_req) self.do_not_get_new_batch = False elif isinstance(recv_req, FlushCacheReq): @@ -407,7 +410,7 @@ class ModelTpServer: def handle_embedding_request( self, - recv_req: TokenizedEmbeddingReqInput, + recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput], ): req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py new file mode 100644 index 000000000..519d9a0d2 --- /dev/null +++ b/python/sglang/srt/models/llama_reward.py @@ -0,0 +1,142 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig +from vllm.config import CacheConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel + + +class LlamaForSequenceClassification(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.torchao_config = None + self.quant_config = quant_config + self.num_labels = config.num_labels + self.model = LlamaModel(config, quant_config=quant_config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) + + self.eos_token_id = config.eos_token_id + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> EmbeddingPoolerOutput: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + scores = self.score(hidden_states) + + return self.pooler(scores, input_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "classification_head" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + elif "lm_head" in name: + continue + else: + 
LlamaForCausalLM.load_weights(self, [(name, loaded_weight)]) + + +class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification): + class Weights(torch.nn.Module): + def __init__(self, hidden_size, num_label): + super().__init__() + self.fc = torch.nn.Sequential( + torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16), + torch.nn.SELU(), + torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16), + torch.nn.SELU(), + torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16), + ) + + def forward(self, x): + return self.fc(x.to(torch.float16)) + + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__(config, quant_config, cache_config) + self.weights = self.Weights(config.hidden_size, self.num_labels) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + get_embedding: bool = True, + ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaForSequenceClassification is only used for embedding" + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + logits = self.score(hidden_states) + weights = self.weights(hidden_states) + + pooled_logits = self.pooler(logits, input_metadata).embeddings + pooled_weights = self.pooler(weights, input_metadata).embeddings + + rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view( + -1, self.num_labels // 2 + ) + scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1) + return EmbeddingPoolerOutput(scores) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "classification_head" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + 
weight_loader(param, loaded_weight) + elif "lm_head" in name: + continue + else: + LlamaForCausalLM.load_weights(self, [(name, loaded_weight)]) + + +EntryClass = [ + LlamaForSequenceClassification, + LlamaForSequenceClassificationWithNormal_Weights, +] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9379b9d08..495319f3e 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, + RewardReqInput, UpdateWeightReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -213,6 +214,21 @@ app.post("/encode")(encode_request) app.put("/encode")(encode_request) + +async def judge_request(obj: RewardReqInput, request: Request): + """Handle a reward model (judge) request.""" + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +app.post("/judge")(judge_request) +app.put("/judge")(judge_request) + + @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): return await v1_completions(tokenizer_manager, raw_request) @@ -635,15 +651,26 @@ class Runtime: def encode( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], ): - json_data = { - "text": prompt, - } - response = requests.post( - self.url + "/encode", - json=json_data, - ) + if isinstance(prompt, str) or isinstance(prompt[0], str): + # embedding + json_data = { + "text": prompt, + } + response = requests.post( + self.url + "/encode", + json=json_data, + ) + else: + # reward + json_data = { + "conv": prompt, + } + response = requests.post( + self.url + "/judge", + json=json_data, + ) return json.dumps(response.json()) def __del__(self): diff --git
a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c2d2b5f11..126118406 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -219,6 +219,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False): if ( "LlamaEmbeddingModel" in model_architectures or "MistralModel" in model_architectures + or "LlamaForSequenceClassification" in model_architectures + or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures ): return False else: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 60790a31e..023ff8929 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -65,6 +65,7 @@ class ModelOutput: top_input_logprobs: List[torch.Tensor] = None top_output_logprobs: List[torch.Tensor] = None embed_logits: List[torch.Tensor] = None + scores: List[float] = None class HFRunner: @@ -72,10 +73,10 @@ class HFRunner: def __init__( self, model_path, torch_dtype, - is_generation, + model_type="generation", output_str_only=False, ): - self.is_generation = is_generation + self.model_type = model_type self.output_str_only = output_str_only self.in_queue = mp.Queue() @@ -92,22 +93,41 @@ class HFRunner: ) self.model_proc.start() + def needs_trust_remote_code(self, model_path): + models_needs_trust_remote = [ + "LxzGordon/URM-LLaMa-3.1-8B", + ] + if model_path in models_needs_trust_remote: + return True + return False + def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): - self.tokenizer = get_tokenizer(model_path) - if self.is_generation: + self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype) + + if self.model_type == "generation": self.base_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, trust_remote_code=False, low_cpu_mem_usage=True, ).cuda() - else: + elif self.model_type == "embedding": from sentence_transformers import SentenceTransformer self.model = SentenceTransformer( model_path,
model_kwargs={"torch_dtype": torch_dtype}, - ) + ).cuda() + elif self.model_type == "reward": + from transformers import AutoModelForSequenceClassification + + self.model = AutoModelForSequenceClassification.from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=self.needs_trust_remote_code(model_path), + ).cuda() + else: + raise Exception(f"Unrecognized model type {self.model_type}") while True: prompts, max_new_tokens, lora_paths = in_queue.get() @@ -115,7 +135,7 @@ class HFRunner: assert len(prompts) == len(lora_paths) if prompts is not None: - if self.is_generation: + if self.model_type == "generation": output_strs = [] top_input_logprobs = [] top_output_logprobs = [] @@ -179,11 +199,27 @@ class HFRunner: ) ) - else: + elif self.model_type == "embedding": assert not self.output_str_only logits = self.model.encode(prompts).tolist() out_queue.put(ModelOutput(embed_logits=logits)) + elif self.model_type == "reward": + scores = [] + for conv in prompts: + conv_formatted = self.tokenizer.apply_chat_template( + conv, tokenize=False + ) + conv_tokenized = self.tokenizer( + conv_formatted, return_tensors="pt" + ).to("cuda") + scores.append( + float(self.model(**conv_tokenized).logits[0][0].item()) + ) + out_queue.put(ModelOutput(scores=scores)) + else: + raise Exception(f"Unrecognized model type {self.model_type}") + def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, @@ -210,7 +246,7 @@ class SRTRunner: self, model_path, torch_dtype, - is_generation, + model_type, tp_size=1, port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths=None, @@ -218,13 +254,14 @@ class SRTRunner: disable_cuda_graph=False, disable_radix_cache=False, ): - self.is_generation = is_generation + self.model_type = model_type + self.is_generation = model_type == "generation" self.runtime = Runtime( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, - mem_fraction_static=0.69, + mem_fraction_static=0.65, 
trust_remote_code=False, is_embedding=not self.is_generation, lora_paths=lora_paths, @@ -285,8 +322,12 @@ class SRTRunner: else: response = self.runtime.encode(prompts) response = json.loads(response) - logits = [x["embedding"] for x in response] - return ModelOutput(embed_logits=logits) + if self.model_type == "embedding": + logits = [x["embedding"] for x in response] + return ModelOutput(embed_logits=logits) + else: + scores = [x["embedding"][0] for x in response] + return ModelOutput(scores=scores) def batch_forward( self, @@ -316,8 +357,12 @@ class SRTRunner: else: response = self.runtime.encode(prompts) response = json.loads(response) - logits = [x["embedding"] for x in response] - return ModelOutput(embed_logits=logits) + if self.model_type == "embedding": + logits = [x["embedding"] for x in response] + return ModelOutput(embed_logits=logits) + else: + scores = [x["embedding"][0] for x in response] + return ModelOutput(scores=scores) def __enter__(self): return self diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index a5a73bf31..3ad187cbb 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -39,7 +39,9 @@ class TestEmbeddingModels(unittest.TestCase): prefill_tolerance, ) -> None: with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation=False + model_path, + torch_dtype=torch_dtype, + model_type="embedding", ) as hf_runner: hf_outputs = hf_runner.forward(prompts) @@ -47,7 +49,7 @@ class TestEmbeddingModels(unittest.TestCase): model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=False, + model_type="embedding", ) as srt_runner: srt_outputs = srt_runner.forward(prompts) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 732b3d800..21078e8aa 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -73,7 +73,9 @@ class
TestGenerationModels(unittest.TestCase): max_new_tokens = 32 with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation=True + model_path, + torch_dtype=torch_dtype, + model_type="generation", ) as hf_runner: hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -81,7 +83,7 @@ class TestGenerationModels(unittest.TestCase): model_path, tp_size=model_case.tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", ) as srt_runner: srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py new file mode 100644 index 000000000..cd15b4967 --- /dev/null +++ b/test/srt/models/test_reward_models.py @@ -0,0 +1,91 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import multiprocessing as mp +import unittest + +import torch + +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner + +MODELS = [ + ("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2), +] +TORCH_DTYPES = [torch.float16] + +# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?" +# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. 
Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples." +# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples." + +PROMPT = ( + "What is the range of the numeric output of a sigmoid node in a neural network?" +) +RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1." +RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1." + +CONVS = [ + [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}], + [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}], +] + + +class TestRewardModels(unittest.TestCase): + + def assert_close_reward_scores( + self, + convs, + model_path, + tp_size, + torch_dtype, + tolerance, + ) -> None: + with HFRunner( + model_path, + torch_dtype=torch_dtype, + model_type="reward", + ) as hf_runner: + hf_outputs = hf_runner.forward(convs) + + with SRTRunner( + model_path, + torch_dtype=torch_dtype, + model_type="reward", + ) as srt_runner: + srt_outputs = srt_runner.forward(convs) + + hf_scores = torch.tensor(hf_outputs.scores) + srt_scores = torch.tensor(srt_outputs.scores) + print(hf_scores) + print(srt_scores) + + assert torch.all( + abs(hf_scores - srt_scores) < tolerance + ), "reward scores are not all close" + + def test_reward_scores(self): + for model, tp_size, tolerance in MODELS: + for torch_dtype in TORCH_DTYPES: + self.assert_close_reward_scores( + CONVS, model, tp_size, torch_dtype, tolerance + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 9210be948..b7b81f9dd 100644 --- a/test/srt/run_suite.py +++ 
b/test/srt/run_suite.py @@ -7,6 +7,7 @@ suites = { "minimal": [ "models/test_embedding_models.py", "models/test_generation_models.py", + "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py", "test_embedding_openai_server.py",