[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)
This commit is contained in:
@@ -215,12 +215,11 @@ class EmbeddingReqInput:
|
||||
raise ValueError("Either text or input_ids should be provided.")
|
||||
|
||||
if self.text is not None:
|
||||
is_single = isinstance(self.text, str)
|
||||
self.is_single = isinstance(self.text, str)
|
||||
else:
|
||||
is_single = isinstance(self.input_ids[0], int)
|
||||
self.is_single = is_single
|
||||
self.is_single = isinstance(self.input_ids[0], int)
|
||||
|
||||
if is_single:
|
||||
if self.is_single:
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.sampling_params is None:
|
||||
@@ -254,6 +253,52 @@ class TokenizedEmbeddingReqInput:
|
||||
sampling_params: SamplingParams
|
||||
|
||||
|
||||
@dataclass
class RewardReqInput:
    """Request payload for scoring one conversation, or a batch of them.

    ``post_init`` is a regular method (not ``__post_init__``), so the
    dataclass machinery does not run it automatically; it has to be called
    explicitly before the request is processed.
    """

    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
    conv: Union[List[List[Dict]], List[Dict]]
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Union[List[Dict], Dict] = None

    # True when `conv` is a single conversation; recomputed by `post_init`.
    is_single: bool = True

    def post_init(self):
        """Normalize the request in place.

        Detects whether ``conv`` is a single conversation (a list of message
        dicts) or a batch (a list of such lists), fills in missing request
        ids, and forces ``max_new_tokens`` to 1 in every sampling-params
        dict. For a batch, ``batch_size`` is also set.

        Raises:
            ValueError: if ``conv`` is empty, or if a batch is given together
                with a ``rid`` that is not a list.
        """
        if not self.conv:
            # Fail with the same exception type as the other validation
            # errors instead of an IndexError from `self.conv[0]`.
            raise ValueError("The conv should not be empty.")

        self.is_single = isinstance(self.conv[0], dict)

        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 1
        else:
            # support select operation
            self.batch_size = len(self.conv)
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            elif not isinstance(self.rid, list):
                raise ValueError("The rid should be a list.")

            if self.sampling_params is None:
                # One independent dict per request: `[{}] * n` would make
                # every entry the very same dict object.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif isinstance(self.sampling_params, dict):
                # Broadcast a single dict to the whole batch, copying it so
                # the entries (and the caller's dict) stay independent.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]

            # Indexed on purpose: a caller-supplied list shorter than the
            # batch still surfaces as an IndexError, as before.
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 1
|
||||
|
||||
|
||||
@dataclass
class TokenizedRewardReqInput:
    """A single reward request after tokenization.

    Field order matters: instances are built positionally as
    ``(rid, input_text, input_ids, sampling_params)``.
    """

    # Identifier of the request
    rid: str
    # Text that was tokenized (for reward requests, the chat-templated conversation)
    input_text: str
    # Token ids obtained by encoding the input text
    input_ids: List[int]
    # Placeholder sampling params, kept only for interface compatibility
    sampling_params: SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
# The request id
|
||||
|
||||
@@ -46,8 +46,10 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
RewardReqInput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
TokenizedRewardReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
)
|
||||
@@ -142,7 +144,7 @@ class TokenizerManager:
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
@@ -163,7 +165,7 @@ class TokenizerManager:
|
||||
|
||||
async def _handle_single_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
is_cache_for_prefill: Optional[bool] = False,
|
||||
@@ -173,7 +175,13 @@ class TokenizerManager:
|
||||
|
||||
rid = obj.rid if not_use_index else obj.rid[index]
|
||||
input_text = obj.text if not_use_index else obj.text[index]
|
||||
if obj.input_ids is None:
|
||||
if hasattr(obj, "conv"):
|
||||
# reward model
|
||||
assert self.tokenizer is not None
|
||||
conv = obj.conv if not_use_index else obj.conv[index]
|
||||
input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
elif obj.input_ids is None:
|
||||
assert self.tokenizer is not None
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
@@ -269,13 +277,21 @@ class TokenizerManager:
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else: # is embedding
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
assert isinstance(obj, RewardReqInput)
|
||||
tokenized_obj = TokenizedRewardReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
# Recv results
|
||||
@@ -292,7 +308,7 @@ class TokenizerManager:
|
||||
|
||||
async def _handle_batch_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
batch_size = obj.batch_size
|
||||
@@ -329,9 +345,16 @@ class TokenizerManager:
|
||||
rid = obj.rid[index]
|
||||
if parallel_sample_num == 1:
|
||||
## select operation
|
||||
if obj.input_ids is None:
|
||||
if hasattr(obj, "conv"):
|
||||
# reward model
|
||||
conv = obj.conv[i]
|
||||
input_text = self.tokenizer.apply_chat_template(
|
||||
conv, tokenize=False
|
||||
)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
elif obj.input_ids is None:
|
||||
input_text = obj.text[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
@@ -370,13 +393,21 @@ class TokenizerManager:
|
||||
else obj.lora_path
|
||||
),
|
||||
)
|
||||
else:
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
assert isinstance(obj, RewardReqInput)
|
||||
tokenized_obj = TokenizedRewardReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
self.send_to_controller.send_pyobj(tokenized_obj)
|
||||
|
||||
event = asyncio.Event()
|
||||
@@ -442,7 +473,7 @@ class TokenizerManager:
|
||||
async def _wait_for_response(
|
||||
self,
|
||||
state: ReqState,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
||||
rid: str,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
@@ -469,7 +500,7 @@ class TokenizerManager:
|
||||
),
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
else: # isinstance(obj, EmbeddingReqInput)
|
||||
else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
|
||||
out = state.out_list[-1]
|
||||
|
||||
out["index"] = response_index
|
||||
|
||||
@@ -22,7 +22,7 @@ import os
|
||||
import pickle
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.managers.io_struct import (
|
||||
FlushCacheReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
TokenizedRewardReqInput,
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
)
|
||||
@@ -223,7 +224,9 @@ class ModelTpServer:
|
||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||
self.handle_generate_request(recv_req)
|
||||
self.do_not_get_new_batch = False
|
||||
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
|
||||
elif isinstance(
|
||||
recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
|
||||
):
|
||||
self.handle_embedding_request(recv_req)
|
||||
self.do_not_get_new_batch = False
|
||||
elif isinstance(recv_req, FlushCacheReq):
|
||||
@@ -407,7 +410,7 @@ class ModelTpServer:
|
||||
|
||||
def handle_embedding_request(
|
||||
self,
|
||||
recv_req: TokenizedEmbeddingReqInput,
|
||||
recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
|
||||
):
|
||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
142
python/sglang/srt/models/llama_reward.py
Normal file
142
python/sglang/srt/models/llama_reward.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
class LlamaForSequenceClassification(nn.Module):
    """Llama backbone followed by a bias-free linear scoring head.

    Per-token scores from the head are reduced by a ``Pooler`` configured
    with ``PoolingType.LAST`` and no normalization.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Build the backbone, the scoring head and the pooler.

        ``cache_config`` is accepted but not used by this constructor.
        """
        super().__init__()
        # Plain (non-module) attributes first.
        self.config = config
        self.quant_config = quant_config
        self.torchao_config = None
        self.num_labels = config.num_labels
        self.eos_token_id = config.eos_token_id

        # Sub-modules, registered in the same order as before.
        self.model = LlamaModel(config, quant_config=quant_config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> EmbeddingPoolerOutput:
        """Run the backbone, score its hidden states and pool the scores."""
        backbone_out = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.pooler(self.score(backbone_out), input_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module.

        Names containing ``classification_head`` are loaded directly into the
        parameter of the same name; names containing ``lm_head`` are skipped;
        everything else is handed, one tensor at a time, to
        ``LlamaForCausalLM.load_weights``.
        """
        own_params = dict(self.named_parameters())

        for weight_name, tensor in weights:
            if "classification_head" in weight_name:
                target = own_params[weight_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
                continue
            if "lm_head" in weight_name:
                # This model has no language-model head.
                continue
            LlamaForCausalLM.load_weights(self, [(weight_name, tensor)])
|
||||
|
||||
|
||||
class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
    """Sequence classifier whose final score is a learned weighted sum.

    On top of the parent's scoring head, a small MLP (``Weights``) produces
    ``num_labels // 2`` weights from the hidden states. The pooled logits are
    viewed as ``num_labels // 2`` consecutive pairs; the first element of
    each pair is multiplied by the matching pooled weight and the products
    are summed into one score per sequence.

    Weight loading is inherited unchanged from
    ``LlamaForSequenceClassification``.
    """

    class Weights(torch.nn.Module):
        """Three-layer SELU MLP, evaluated in float16, with ``num_label // 2`` outputs."""

        def __init__(self, hidden_size, num_label):
            super().__init__()
            self.fc = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
                torch.nn.SELU(),
                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
                torch.nn.SELU(),
                torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
            )

        def forward(self, x):
            # The layers are float16, so the input is cast to match.
            return self.fc(x.to(torch.float16))

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Build the parent classifier and add the weighting MLP."""
        super().__init__(config, quant_config, cache_config)
        self.weights = self.Weights(config.hidden_size, self.num_labels)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Return one weighted reward score per sequence, shaped ``(-1, 1)``.

        Raises:
            AssertionError: if ``get_embedding`` is false.
        """
        assert (
            get_embedding
        ), "LlamaForSequenceClassification is only used for embedding"
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        logits = self.score(hidden_states)
        weights = self.weights(hidden_states)

        pooled_logits = self.pooler(logits, input_metadata).embeddings
        pooled_weights = self.pooler(weights, input_metadata).embeddings

        # Group the logits into consecutive pairs and keep the first element
        # of each pair; the result already has shape (-1, num_labels // 2).
        half = self.num_labels // 2
        rews = pooled_logits.view(-1, half, 2)[:, :, 0]
        scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
        return EmbeddingPoolerOutput(scores)
|
||||
|
||||
|
||||
# Model classes defined in this module.
# NOTE(review): presumably the model loader discovers architectures through
# `EntryClass` — confirm against the registry code.
EntryClass = [
    LlamaForSequenceClassification,
    LlamaForSequenceClassificationWithNormal_Weights,
]
|
||||
@@ -54,6 +54,7 @@ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
RewardReqInput,
|
||||
UpdateWeightReqInput,
|
||||
)
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -213,6 +214,21 @@ app.post("/encode")(encode_request)
|
||||
app.put("/encode")(encode_request)
|
||||
|
||||
|
||||
async def judge_request(obj: RewardReqInput, request: Request):
    """Handle a reward request.

    Awaits the first item produced by ``tokenizer_manager.generate_request``
    and returns it. A ``ValueError`` raised while doing so is turned into a
    JSON error body with HTTP status 400 (BAD_REQUEST).
    """
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
|
||||
|
||||
|
||||
app.post("/judge")(judge_request)
|
||||
app.put("/judge")(judge_request)
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_v1_completions(raw_request: Request):
|
||||
return await v1_completions(tokenizer_manager, raw_request)
|
||||
@@ -635,15 +651,26 @@ class Runtime:
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
):
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/encode",
|
||||
json=json_data,
|
||||
)
|
||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
||||
# embedding
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/encode",
|
||||
json=json_data,
|
||||
)
|
||||
else:
|
||||
# reward
|
||||
json_data = {
|
||||
"conv": prompt,
|
||||
}
|
||||
response = requests.post(
|
||||
self.url + "/judge",
|
||||
json=json_data,
|
||||
)
|
||||
return json.dumps(response.json())
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@@ -219,6 +219,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
|
||||
if (
|
||||
"LlamaEmbeddingModel" in model_architectures
|
||||
or "MistralModel" in model_architectures
|
||||
or "LlamaForSequenceClassification" in model_architectures
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
|
||||
@@ -65,6 +65,7 @@ class ModelOutput:
|
||||
top_input_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprobs: List[torch.Tensor] = None
|
||||
embed_logits: List[torch.Tensor] = None
|
||||
scores: List[float] = None
|
||||
|
||||
|
||||
class HFRunner:
|
||||
@@ -72,10 +73,10 @@ class HFRunner:
|
||||
self,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation,
|
||||
model_type="generation",
|
||||
output_str_only=False,
|
||||
):
|
||||
self.is_generation = is_generation
|
||||
self.model_type = model_type
|
||||
self.output_str_only = output_str_only
|
||||
|
||||
self.in_queue = mp.Queue()
|
||||
@@ -92,22 +93,41 @@ class HFRunner:
|
||||
)
|
||||
self.model_proc.start()
|
||||
|
||||
def needs_trust_remote_code(self, model_path):
    """Return True if `model_path` is listed here as needing `trust_remote_code`."""
    remote_code_models = ("LxzGordon/URM-LLaMa-3.1-8B",)
    return model_path in remote_code_models
|
||||
|
||||
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
||||
self.tokenizer = get_tokenizer(model_path)
|
||||
if self.is_generation:
|
||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
||||
|
||||
if self.model_type == "generation":
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=False,
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
else:
|
||||
elif self.model_type == "embedding":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.model = SentenceTransformer(
|
||||
model_path,
|
||||
model_kwargs={"torch_dtype": torch_dtype},
|
||||
)
|
||||
).cuda()
|
||||
elif self.model_type == "reward":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=self.needs_trust_remote_code(model_path),
|
||||
).cuda()
|
||||
else:
|
||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||
|
||||
while True:
|
||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
||||
@@ -115,7 +135,7 @@ class HFRunner:
|
||||
assert len(prompts) == len(lora_paths)
|
||||
|
||||
if prompts is not None:
|
||||
if self.is_generation:
|
||||
if self.model_type == "generation":
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
@@ -179,11 +199,27 @@ class HFRunner:
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
elif self.model_type == "embedding":
|
||||
assert not self.output_str_only
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
out_queue.put(ModelOutput(embed_logits=logits))
|
||||
|
||||
elif self.model_type == "reward":
|
||||
scores = []
|
||||
for conv in prompts:
|
||||
conv_formatted = self.tokenizer.apply_chat_template(
|
||||
conv, tokenize=False
|
||||
)
|
||||
conv_tokenized = self.tokenizer(
|
||||
conv_formatted, return_tensors="pt"
|
||||
).to("cuda")
|
||||
scores.append(
|
||||
float(self.model(**conv_tokenized).logits[0][0].item())
|
||||
)
|
||||
out_queue.put(ModelOutput(scores=scores))
|
||||
else:
|
||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
@@ -210,7 +246,7 @@ class SRTRunner:
|
||||
self,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation,
|
||||
model_type,
|
||||
tp_size=1,
|
||||
port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
lora_paths=None,
|
||||
@@ -218,13 +254,14 @@ class SRTRunner:
|
||||
disable_cuda_graph=False,
|
||||
disable_radix_cache=False,
|
||||
):
|
||||
self.is_generation = is_generation
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
self.runtime = Runtime(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=0.69,
|
||||
mem_fraction_static=0.65,
|
||||
trust_remote_code=False,
|
||||
is_embedding=not self.is_generation,
|
||||
lora_paths=lora_paths,
|
||||
@@ -285,8 +322,12 @@ class SRTRunner:
|
||||
else:
|
||||
response = self.runtime.encode(prompts)
|
||||
response = json.loads(response)
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
if self.model_type == "embedding":
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
else:
|
||||
scores = [x["embedding"][0] for x in response]
|
||||
return ModelOutput(scores=scores)
|
||||
|
||||
def batch_forward(
|
||||
self,
|
||||
@@ -316,8 +357,12 @@ class SRTRunner:
|
||||
else:
|
||||
response = self.runtime.encode(prompts)
|
||||
response = json.loads(response)
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
if self.model_type == "embedding":
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
else:
|
||||
scores = [x["embedding"][0] for x in response]
|
||||
return ModelOutput(scores=logits)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user