Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
@@ -94,7 +94,10 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-        self.is_generation = is_generation_model(self.hf_config.architectures)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, self.server_args.is_embedding
+        )

         if server_args.context_length is not None:
             self.context_len = server_args.context_length
@@ -94,6 +94,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )

         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -204,7 +204,7 @@ class ModelRunner:
                 else None
             )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )

         logger.info(
@@ -522,9 +522,18 @@ class ModelRunner:
             batch,
             forward_mode=ForwardMode.EXTEND,
         )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )

     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
@@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -275,6 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(
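For context, PoolingType.LAST with normalize=True amounts to taking each sequence's final hidden state and L2-normalizing it. A minimal sketch of that behavior, assuming a flattened [total_tokens, hidden_size] layout with per-sequence lengths (the function name and tensor layout are illustrative, not the actual Pooler API):

    # Illustrative sketch only: pick each sequence's last hidden state and
    # L2-normalize it. The layout assumptions are for this example, not Pooler.
    import torch
    import torch.nn.functional as F

    def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        last_indices = torch.cumsum(seq_lens, dim=0) - 1   # index of each sequence's last token
        embeddings = hidden_states[last_indices]           # [batch_size, hidden_size]
        return F.normalize(embeddings, p=2.0, dim=-1)      # unit-length embeddings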
@@ -283,11 +285,15 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+        else:
+            return self.pooler(hidden_states, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -333,11 +333,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi

     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()

     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -515,6 +517,7 @@ class Runtime:

         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)

         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
@@ -38,6 +38,7 @@ class ServerArgs:
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False

     # Port
     host: str = "127.0.0.1"
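Because is_embedding is a plain ServerArgs field, it can also be supplied through the Python Runtime wrapper, which is how the SRTRunner test harness further down uses it. A minimal sketch, assuming Runtime forwards keyword arguments into ServerArgs:

    # Hedged sketch: serve a CausalLM checkpoint as an embedding model via the
    # Python API. Assumes Runtime(**kwargs) forwards is_embedding into ServerArgs.
    import sglang as sgl

    runtime = sgl.Runtime(
        model_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
        is_embedding=True,   # new flag introduced by this change
    )
    # ... use the runtime, then shut it down
    runtime.shutdown()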
@@ -200,6 +201,11 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
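A hedged end-to-end sketch of serving with the new flag: the launch command uses the existing sglang.launch_server entry point, while the /v1/embeddings route and response shape are assumptions modeled on the OpenAI-compatible API rather than something this diff adds, so adjust them to whatever endpoint your build exposes:

    # Launch (shell):
    #   python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \
    #       --is-embedding --port 30000
    #
    # Then request an embedding over HTTP. Endpoint and payload shape are assumptions.
    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/v1/embeddings",
        json={"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": "hello world"},
    )
    embedding = resp.json()["data"][0]["embedding"]
    print(len(embedding), embedding[:4])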
@@ -458,6 +464,11 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
             logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model:
+    # 1. Check the model architecture.
+    # 2. Check the `is_embedding` server argument.
+
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    return True
+    else:
+        return not is_embedding


 def decode_video_base64(video_base64):
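To make the decision rule concrete, a few illustrative calls and their expected results (the import path is an assumption based on the surrounding file, which also defines is_multimodal_model):

    # Expected results under the logic above; purely illustrative.
    from sglang.srt.utils import is_generation_model

    is_generation_model(["Qwen2ForCausalLM"])                      # True: generative by architecture
    is_generation_model(["Qwen2ForCausalLM"], is_embedding=True)   # False: CausalLM served as an embedder
    is_generation_model(["LlamaEmbeddingModel"])                   # False: embedding-only architecture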
@@ -14,7 +14,7 @@ limitations under the License.
 """

 import json
-import multiprocessing
+import multiprocessing as mp
 import os
 from dataclasses import dataclass
 from typing import List, Union
@@ -63,37 +63,35 @@ class HFRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
     ):
-        self.in_queue = multiprocessing.Queue()
-        self.out_queue = multiprocessing.Queue()
+        self.is_generation = is_generation

-        self.model_proc = multiprocessing.Process(
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
+
+        self.model_proc = mp.Process(
             target=self.start_model_process,
             args=(
                 self.in_queue,
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_generation_model,
             ),
         )
         self.model_proc.start()

-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
         )

-        self.is_generation_model = is_generation_model
-
-        if self.is_generation_model:
+        if self.is_generation:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
         else:
@@ -107,7 +105,7 @@ class HFRunner:
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if self.is_generation_model:
+                if self.is_generation:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -171,17 +169,19 @@ class SRTRunner:
         self,
         model_path,
         torch_dtype,
-        is_generation_model,
+        is_generation,
         tp_size=1,
         port=5157,
     ):
-        self.is_generation_model = is_generation_model
+        self.is_generation = is_generation
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=0.7,
             trust_remote_code=False,
+            is_embedding=not self.is_generation,
         )

     def forward(
@@ -189,7 +189,7 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
     ):
-        if self.is_generation_model:
+        if self.is_generation:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
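A sketch of how the renamed runner might be driven in embedding mode; the argument names follow the signatures visible in the hunks above, while the prompt and surrounding imports are illustrative:

    # Illustrative only: exercise SRTRunner (defined in the test-runner module
    # patched above) with a model served as an embedder.
    import torch

    runner = SRTRunner(
        "Alibaba-NLP/gte-Qwen2-7B-instruct",
        torch_dtype=torch.float16,
        is_generation=False,   # embedding mode; the runner passes is_embedding=True to Runtime
    )
    outputs = runner.forward(prompts=["The capital of France is Paris."])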