[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)

.github/workflows/pr-test.yml (vendored, 9 changes)
@@ -29,6 +29,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[dev]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Run test
@@ -48,6 +49,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[dev]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Run test
@@ -67,6 +69,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[dev]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Run test
@@ -86,6 +89,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[dev]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Run test
@@ -105,6 +109,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[all]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Benchmark Single Latency
@@ -136,6 +141,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[all]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Benchmark Offline Throughput (w/o RadixAttention)
@@ -167,6 +173,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[all]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

     - name: Benchmark Offline Throughput (TP=2)
@@ -198,6 +205,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[all]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

         git clone https://github.com/merrymercy/human-eval.git
@@ -221,6 +229,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install -e "python[all]"
+        pip install transformers==4.44
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

         git clone https://github.com/merrymercy/human-eval.git
examples/runtime/reward_model.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+# launch server
+# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding
+
+import json
+
+import requests
+
+url = "http://127.0.0.1:30000"
+
+PROMPT = (
+    "What is the range of the numeric output of a sigmoid node in a neural network?"
+)
+RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
+RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
+
+json_data = {
+    "conv": [
+        [
+            {"role": "user", "content": PROMPT},
+            {"role": "assistant", "content": RESPONSE1},
+        ],
+        [
+            {"role": "user", "content": PROMPT},
+            {"role": "assistant", "content": RESPONSE2},
+        ],
+    ],
+}
+response = requests.post(
+    url + "/judge",
+    json=json_data,
+).json()
+
+print(response)
+print("scores:", [x["embedding"] for x in response])
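The example above drives the reward model over HTTP. The same scoring can also be reached through the in-process Runtime API that this commit extends further below (Runtime.encode routes a list of chat messages to /judge instead of /encode). A minimal sketch, assuming a local GPU that can hold the model and an sglang build containing this change; the conversation strings are illustrative:

    # Sketch: reward scoring via the in-process Runtime API.
    import json

    from sglang.srt.server import Runtime

    runtime = Runtime(model_path="LxzGordon/URM-LLaMa-3.1-8B", is_embedding=True)

    conv = [
        {"role": "user", "content": "What is the range of a sigmoid output?"},
        {"role": "assistant", "content": "It is bounded between 0 and 1."},
    ]

    # A list of message dicts (rather than a plain string) is routed to /judge.
    print(json.loads(runtime.encode(conv)))

    runtime.shutdown()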
@@ -215,12 +215,11 @@ class EmbeddingReqInput:
             raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
-            is_single = isinstance(self.text, str)
+            self.is_single = isinstance(self.text, str)
         else:
-            is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+            self.is_single = isinstance(self.input_ids[0], int)

-        if is_single:
+        if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams


+@dataclass
+class RewardReqInput:
+    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
+    conv: Union[List[List[Dict]], List[Dict]]
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    is_single: bool = True
+
+    def post_init(self):
+        self.is_single = isinstance(self.conv[0], dict)
+
+        if self.is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = len(self.conv)
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedRewardReqInput:
+    # The request id
+    rid: str
+    # The input text
+    input_text: str
+    # The input token ids
+    input_ids: List[int]
+    # Dummy sampling params for compatibility
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
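To make the request normalization above concrete, here is a small usage sketch of RewardReqInput.post_init on a batch of two conversations (assumes an sglang install containing this change; conversation contents and printed values are illustrative):

    from sglang.srt.managers.io_struct import RewardReqInput

    req = RewardReqInput(
        conv=[
            [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
            [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}],
        ]
    )
    req.post_init()
    print(req.is_single)        # False: conv[0] is a list, so this is a batch
    print(req.batch_size)       # 2
    print(req.sampling_params)  # [{'max_new_tokens': 1}, {'max_new_tokens': 1}]
    print(len(req.rid))         # 2 auto-generated request ids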
@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -142,7 +144,7 @@ class TokenizerManager:

     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
@@ -163,7 +165,7 @@ class TokenizerManager:

     async def _handle_single_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
         request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
@@ -173,7 +175,13 @@ class TokenizerManager:

         rid = obj.rid if not_use_index else obj.rid[index]
         input_text = obj.text if not_use_index else obj.text[index]
-        if obj.input_ids is None:
+        if hasattr(obj, "conv"):
+            # reward model
+            assert self.tokenizer is not None
+            conv = obj.conv if not_use_index else obj.conv[index]
+            input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
+            input_ids = self.tokenizer.encode(input_text)
+        elif obj.input_ids is None:
             assert self.tokenizer is not None
             input_ids = self.tokenizer.encode(input_text)
         else:
@@ -269,13 +277,21 @@ class TokenizerManager:
                     else obj.lora_path
                 ),
             )
-        else:  # is embedding
+        elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
                 rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )
+        else:
+            assert isinstance(obj, RewardReqInput)
+            tokenized_obj = TokenizedRewardReqInput(
+                rid,
+                input_text,
+                input_ids,
+                sampling_params,
+            )
         self.send_to_controller.send_pyobj(tokenized_obj)

         # Recv results
@@ -292,7 +308,7 @@ class TokenizerManager:

     async def _handle_batch_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         batch_size = obj.batch_size
@@ -329,9 +345,16 @@ class TokenizerManager:
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
                     ## select operation
-                    if obj.input_ids is None:
+                    if hasattr(obj, "conv"):
+                        # reward model
+                        conv = obj.conv[i]
+                        input_text = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        input_ids = self.tokenizer.encode(input_text)
+                    elif obj.input_ids is None:
                         input_text = obj.text[i]
-                        input_ids = self.tokenizer.encode(obj.text[i])
+                        input_ids = self.tokenizer.encode(input_text)
                     else:
                         input_text = None
                         input_ids = obj.input_ids[i]
@@ -370,13 +393,21 @@ class TokenizerManager:
                             else obj.lora_path
                         ),
                     )
-                else:
+                elif isinstance(obj, EmbeddingReqInput):
                     tokenized_obj = TokenizedEmbeddingReqInput(
                         rid,
                         input_text,
                         input_ids,
                         sampling_params,
                     )
+                else:
+                    assert isinstance(obj, RewardReqInput)
+                    tokenized_obj = TokenizedRewardReqInput(
+                        rid,
+                        input_text,
+                        input_ids,
+                        sampling_params,
+                    )
                 self.send_to_controller.send_pyobj(tokenized_obj)

                 event = asyncio.Event()
@@ -442,7 +473,7 @@ class TokenizerManager:
     async def _wait_for_response(
         self,
         state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
         rid: str,
         request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
@@ -469,7 +500,7 @@ class TokenizerManager:
                     ),
                     obj.return_text_in_logprobs,
                 )
-            else:  # isinstance(obj, EmbeddingReqInput)
+            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                 out = state.out_list[-1]

             out["index"] = response_index
@@ -22,7 +22,7 @@ import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import torch
 import torch.distributed
@@ -41,6 +41,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -223,7 +224,9 @@ class ModelTpServer:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
                 self.do_not_get_new_batch = False
-            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+            elif isinstance(
+                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
+            ):
                 self.handle_embedding_request(recv_req)
                 self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
@@ -407,7 +410,7 @@ class ModelTpServer:

     def handle_embedding_request(
         self,
-        recv_req: TokenizedEmbeddingReqInput,
+        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
python/sglang/srt/models/llama_reward.py (new file, 142 lines)
@@ -0,0 +1,142 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+
+
+class LlamaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.torchao_config = None
+        self.quant_config = quant_config
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        scores = self.score(hidden_states)
+
+        return self.pooler(scores, input_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
+    class Weights(torch.nn.Module):
+        def __init__(self, hidden_size, num_label):
+            super().__init__()
+            self.fc = torch.nn.Sequential(
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
+            )
+
+        def forward(self, x):
+            return self.fc(x.to(torch.float16))
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)
+        self.weights = self.Weights(config.hidden_size, self.num_labels)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForSequenceClassification is only used for embedding"
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        logits = self.score(hidden_states)
+        weights = self.weights(hidden_states)
+
+        pooled_logits = self.pooler(logits, input_metadata).embeddings
+        pooled_weights = self.pooler(weights, input_metadata).embeddings
+
+        rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
+            -1, self.num_labels // 2
+        )
+        scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
+        return EmbeddingPoolerOutput(scores)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+EntryClass = [
+    LlamaForSequenceClassification,
+    LlamaForSequenceClassificationWithNormal_Weights,
+]
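The gated head in LlamaForSequenceClassificationWithNormal_Weights pools the last-token hidden state, projects it to num_labels logits, keeps the first element of each pair of logits as a per-attribute reward, and combines those attributes with a gating vector produced by the Weights MLP. A standalone sketch of just that combination step with toy tensors (random values standing in for the pooled model outputs, shapes only):

    import torch

    batch, num_labels = 2, 8                              # num_labels // 2 = 4 attributes
    pooled_logits = torch.randn(batch, num_labels)        # stands in for pooler(score(h))
    pooled_weights = torch.randn(batch, num_labels // 2)  # stands in for pooler(Weights(h))

    # Keep the first element of each pair of logits as the per-attribute reward ...
    rews = pooled_logits.view(-1, num_labels // 2, 2)[:, :, 0].view(-1, num_labels // 2)
    # ... then collapse the attributes into one scalar score per sequence.
    scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
    print(scores.shape)  # torch.Size([2, 1])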
@@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
+    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -213,6 +214,21 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)


+async def judge_request(obj: RewardReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/judge")(judge_request)
+app.put("/judge")(judge_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -635,15 +651,26 @@ class Runtime:

     def encode(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-        json_data = {
-            "text": prompt,
-        }
-        response = requests.post(
-            self.url + "/encode",
-            json=json_data,
-        )
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            # embedding
+            json_data = {
+                "text": prompt,
+            }
+            response = requests.post(
+                self.url + "/encode",
+                json=json_data,
+            )
+        else:
+            # reward
+            json_data = {
+                "conv": prompt,
+            }
+            response = requests.post(
+                self.url + "/judge",
+                json=json_data,
+            )
        return json.dumps(response.json())

     def __del__(self):
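Because the /judge handler above reuses the embedding response path, each returned item carries its scalar reward score in the "embedding" field; both the example script and the test runner below read it that way. A small parsing sketch with made-up values (any other response fields are omitted here):

    # Sketch: extracting scores from a /judge response (values are illustrative).
    judge_response = [
        {"embedding": [2.8]},  # score for the first conversation
        {"embedding": [4.1]},  # score for the second conversation
    ]
    scores = [item["embedding"][0] for item in judge_response]
    print(scores)  # [2.8, 4.1]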
@@ -219,6 +219,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
     ):
         return False
     else:
@@ -65,6 +65,7 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None


 class HFRunner:
@@ -72,10 +73,10 @@ class HFRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation,
+        model_type="generation",
         output_str_only=False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
         self.output_str_only = output_str_only

         self.in_queue = mp.Queue()
@@ -92,22 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()

+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = get_tokenizer(model_path)
-        if self.is_generation:
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-        else:
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer

             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")

         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -115,7 +135,7 @@ class HFRunner:
                 assert len(prompts) == len(lora_paths)

             if prompts is not None:
-                if self.is_generation:
+                if self.model_type == "generation":
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
@@ -179,11 +199,27 @@ class HFRunner:
                             )
                         )

-                else:
+                elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))

+                elif self.model_type == "reward":
+                    scores = []
+                    for conv in prompts:
+                        conv_formatted = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        conv_tokenized = self.tokenizer(
+                            conv_formatted, return_tensors="pt"
+                        ).to("cuda")
+                        scores.append(
+                            float(self.model(**conv_tokenized).logits[0][0].item())
+                        )
+                    out_queue.put(ModelOutput(scores=scores))
+                else:
+                    raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -210,7 +246,7 @@ class SRTRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation,
+        model_type,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths=None,
@@ -218,13 +254,14 @@ class SRTRunner:
         disable_cuda_graph=False,
         disable_radix_cache=False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.69,
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -285,8 +322,12 @@ class SRTRunner:
             else:
                 response = self.runtime.encode(prompts)
                 response = json.loads(response)
-                logits = [x["embedding"] for x in response]
-                return ModelOutput(embed_logits=logits)
+                if self.model_type == "embedding":
+                    logits = [x["embedding"] for x in response]
+                    return ModelOutput(embed_logits=logits)
+                else:
+                    scores = [x["embedding"][0] for x in response]
+                    return ModelOutput(scores=scores)

     def batch_forward(
         self,
@@ -316,8 +357,12 @@ class SRTRunner:
             else:
                 response = self.runtime.encode(prompts)
                 response = json.loads(response)
-                logits = [x["embedding"] for x in response]
-                return ModelOutput(embed_logits=logits)
+                if self.model_type == "embedding":
+                    logits = [x["embedding"] for x in response]
+                    return ModelOutput(embed_logits=logits)
+                else:
+                    scores = [x["embedding"][0] for x in response]
+                    return ModelOutput(scores=logits)

     def __enter__(self):
         return self
@@ -39,7 +39,9 @@ class TestEmbeddingModels(unittest.TestCase):
         prefill_tolerance,
     ) -> None:
         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation=False
+            model_path,
+            torch_dtype=torch_dtype,
+            model_type="embedding",
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts)

@@ -47,7 +49,7 @@ class TestEmbeddingModels(unittest.TestCase):
             model_path,
             tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=False,
+            model_type="embedding",
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts)

@@ -73,7 +73,9 @@ class TestGenerationModels(unittest.TestCase):
         max_new_tokens = 32

         with HFRunner(
-            model_path, torch_dtype=torch_dtype, is_generation=True
+            model_path,
+            torch_dtype=torch_dtype,
+            model_type="generation",
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

@@ -81,7 +83,7 @@ class TestGenerationModels(unittest.TestCase):
             model_path,
             tp_size=model_case.tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

test/srt/models/test_reward_models.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import multiprocessing as mp
+import unittest
+
+import torch
+
+from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+
+MODELS = [
+    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
+]
+TORCH_DTYPES = [torch.float16]
+
+# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
+# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples."
+# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples."
+
+PROMPT = (
+    "What is the range of the numeric output of a sigmoid node in a neural network?"
+)
+RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
+RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
+
+CONVS = [
+    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
+    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
+]
+
+
+class TestRewardModels(unittest.TestCase):
+
+    def assert_close_reward_scores(
+        self,
+        convs,
+        model_path,
+        tp_size,
+        torch_dtype,
+        tolerance,
+    ) -> None:
+        with HFRunner(
+            model_path,
+            torch_dtype=torch_dtype,
+            model_type="reward",
+        ) as hf_runner:
+            hf_outputs = hf_runner.forward(convs)
+
+        with SRTRunner(
+            model_path,
+            torch_dtype=torch_dtype,
+            model_type="reward",
+        ) as srt_runner:
+            srt_outputs = srt_runner.forward(convs)
+
+        hf_scores = torch.tensor(hf_outputs.scores)
+        srt_scores = torch.tensor(srt_outputs.scores)
+        print(hf_scores)
+        print(srt_scores)
+
+        assert torch.all(
+            abs(hf_scores - srt_scores) < tolerance
+        ), "reward scores are not all close"
+
+    def test_reward_scores(self):
+        for model, tp_size, tolerance in MODELS:
+            for torch_dtype in TORCH_DTYPES:
+                self.assert_close_reward_scores(
+                    CONVS, model, tp_size, torch_dtype, tolerance
+                )
+
+
+if __name__ == "__main__":
+    try:
+        mp.set_start_method("spawn")
+    except RuntimeError:
+        pass
+
+    unittest.main()
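The new test is registered in the minimal CI suite below; it can also be launched standalone. A small sketch of the equivalent invocation (assumes the repository root as working directory and a GPU machine with the dev extras installed):

    import subprocess

    # Equivalent to running `python3 test/srt/models/test_reward_models.py` directly.
    subprocess.run(["python3", "test/srt/models/test_reward_models.py"], check=True)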
@@ -7,6 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
+        "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_embedding_openai_server.py",