[Feature] Support reward model LxzGordon/URM-LLaMa-3.1-8B (#1525)
@@ -65,6 +65,7 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None
 
 
 class HFRunner:
@@ -72,10 +73,10 @@ class HFRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation,
+        model_type="generation",
         output_str_only=False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
         self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
@@ -92,22 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()
 
+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = get_tokenizer(model_path)
-        if self.is_generation:
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-        else:
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer
 
             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")
 
         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
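For reference, the new "reward" branch above is the standard transformers sequence-classification loading path. A minimal standalone sketch, assuming a CUDA machine and bfloat16 weights (neither is pinned down by the diff itself; the runner receives the dtype as a parameter):

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_path = "LxzGordon/URM-LLaMa-3.1-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,  # assumed dtype for this sketch
        trust_remote_code=True,      # mirrors needs_trust_remote_code(model_path)
    ).cuda()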
@@ -115,7 +135,7 @@ class HFRunner:
             assert len(prompts) == len(lora_paths)
 
         if prompts is not None:
-            if self.is_generation:
+            if self.model_type == "generation":
                 output_strs = []
                 top_input_logprobs = []
                 top_output_logprobs = []
@@ -179,11 +199,27 @@ class HFRunner:
                     )
                 )
 
-            else:
+            elif self.model_type == "embedding":
                 assert not self.output_str_only
                 logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
 
+            elif self.model_type == "reward":
+                scores = []
+                for conv in prompts:
+                    conv_formatted = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    conv_tokenized = self.tokenizer(
+                        conv_formatted, return_tensors="pt"
+                    ).to("cuda")
+                    scores.append(
+                        float(self.model(**conv_tokenized).logits[0][0].item())
+                    )
+                out_queue.put(ModelOutput(scores=scores))
+            else:
+                raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
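The scoring loop in this hunk follows the usual chat-template pattern: format the conversation, tokenize, then read one scalar out of the classification head. A hedged sketch continuing the loading example above, with a made-up conversation (the test code runs without torch.no_grad(); it is added here only as an inference nicety):

    # Hypothetical input: one conversation in chat-message format.
    conv = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
    conv_formatted = tokenizer.apply_chat_template(conv, tokenize=False)
    conv_tokenized = tokenizer(conv_formatted, return_tensors="pt").to("cuda")
    with torch.no_grad():
        score = float(model(**conv_tokenized).logits[0][0].item())
    # score is a single scalar reward for the whole conversation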
@@ -210,7 +246,7 @@ class SRTRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation,
+        model_type,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths=None,
@@ -218,13 +254,14 @@ class SRTRunner:
         disable_cuda_graph=False,
         disable_radix_cache=False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.69,
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -285,8 +322,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
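Because SRTRunner launches the Runtime with is_embedding=not self.is_generation, a reward model is served through the same encode endpoint as an embedding model, and its scalar score comes back as a length-1 "embedding" vector. A sketch of the assumed response shape (all field values invented):

    # Assumed shape of json.loads(self.runtime.encode(prompts)) for a
    # reward model: one entry per prompt, score wrapped in a 1-element list.
    response = [
        {"embedding": [0.8203125]},   # hypothetical score for prompt 0
        {"embedding": [-1.4140625]},  # hypothetical score for prompt 1
    ]
    scores = [x["embedding"][0] for x in response]
    # -> [0.8203125, -1.4140625]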
@@ -316,8 +357,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)
 
     def __enter__(self):
         return self
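Taken together, a test exercising the new model type would presumably look like the sketch below. Everything here is illustrative: the conversation, the tolerance, and the assumption that HFRunner supports the context-manager protocol like SRTRunner (only SRTRunner's __enter__ is visible in this diff). HFRunner applies the chat template itself, while the SRT encode path is assumed to take pre-formatted strings:

    import torch
    from transformers import AutoTokenizer

    MODEL_PATH = "LxzGordon/URM-LLaMa-3.1-8B"
    CONVS = [[
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7 is a prime number."},
    ]]

    # HF side: raw conversations; the runner formats them internally.
    with HFRunner(MODEL_PATH, torch.bfloat16, model_type="reward") as hf_runner:
        hf_scores = hf_runner.forward(CONVS).scores

    # SRT side: apply the chat template before sending plain strings.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    prompts = [tokenizer.apply_chat_template(c, tokenize=False) for c in CONVS]
    with SRTRunner(MODEL_PATH, torch.bfloat16, model_type="reward") as srt_runner:
        srt_scores = srt_runner.forward(prompts).scores

    # Illustrative tolerance; dtype and kernel differences cause small drift.
    for hf_s, srt_s in zip(hf_scores, srt_scores):
        assert abs(hf_s - srt_s) < 2e-2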