diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index edf7d75ed..e43caf5f0 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -35,6 +35,7 @@ jobs: pip install -e "python[all]" pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall pip install accelerate + pip install sentence_transformers - name: Test Frontend Language run: | diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 0bd03d314..623ffe916 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -25,7 +25,11 @@ import zmq import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut +from sglang.srt.managers.io_struct import ( + BatchEmbeddingOut, + BatchStrOut, + BatchTokenIDOut, +) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry @@ -66,6 +70,18 @@ class DetokenizerManager: async def handle_loop(self): while True: recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() + + if isinstance(recv_obj, BatchEmbeddingOut): + self.send_to_tokenizer.send_pyobj( + BatchEmbeddingOut( + rids=recv_obj.rids, + embeddings=recv_obj.embeddings, + meta_info=recv_obj.meta_info, + finished_reason=recv_obj.finished_reason, + ) + ) + continue + assert isinstance(recv_obj, BatchTokenIDOut) bs = len(recv_obj.rids) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 4f89ba3b9..e7c5cba92 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -143,6 +143,7 @@ class Req: # Logprobs self.return_logprob = False + self.embedding = None self.logprob_start_len = 0 self.top_logprobs_num = 0 self.normalized_prompt_logprob = None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d2bf24c85..8711c127d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -21,7 +21,7 @@ import dataclasses import logging import multiprocessing as mp import os -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import numpy as np import transformers @@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import ( ) from sglang.srt.managers.io_struct import ( AbortReq, + BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import is_multimodal_model, load_image +from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -85,6 +88,7 @@ class TokenizerManager: trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) + self.is_generation = is_generation_model(self.hf_config.architectures) if server_args.context_length is not None: self.context_len = server_args.context_length @@ -133,7 +137,9 @@ class TokenizerManager: image_data, aspect_ratio, grid_pinpoints, self.processor ) - async def generate_request(self, obj: GenerateReqInput, request=None): + async def generate_request( + self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None + ): if self.to_create_loop: self.create_handle_loop() @@ -144,6 +150,8 @@ class TokenizerManager: async for response in self._handle_single_request(obj, request): yield response else: + if isinstance(obj, EmbeddingReqInput): + raise NotImplementedError("Please send only one prompt in each request") if obj.stream: raise ValueError("Do not support stream for batch mode.") @@ -151,39 +159,47 @@ class TokenizerManager: yield response async def _handle_single_request( - self, obj, request, index=None, is_cache_for_prefill=False + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request, + index=None, + is_cache_for_prefill=False, ): if not is_cache_for_prefill: # The normal case with a single prompt not_use_index = index is None rid = obj.rid if not_use_index else obj.rid[index] input_text = obj.text if not_use_index else obj.text[index] - input_ids = ( - self.tokenizer.encode(input_text) - if obj.input_ids is None - else obj.input_ids - ) - if not not_use_index and obj.input_ids: - input_ids = obj.input_ids[index] + if obj.input_ids is None: + input_ids = self.tokenizer.encode(input_text) + else: + input_ids = obj.input_ids if not_use_index else obj.input_ids[index] self._validate_input_length(input_ids) sampling_params = self._get_sampling_params( obj.sampling_params if not_use_index else obj.sampling_params[index] ) - pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data if not_use_index else obj.image_data[index] - ) - return_logprob = ( - obj.return_logprob if not_use_index else obj.return_logprob[index] - ) - logprob_start_len = ( - obj.logprob_start_len if not_use_index else obj.logprob_start_len[index] - ) - top_logprobs_num = ( - obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index] - ) + + if self.is_generation: + pixel_values, image_hash, image_size = await self._get_pixel_values( + obj.image_data if not_use_index else obj.image_data[index] + ) + return_logprob = ( + obj.return_logprob if not_use_index else obj.return_logprob[index] + ) + logprob_start_len = ( + obj.logprob_start_len + if not_use_index + else obj.logprob_start_len[index] + ) + top_logprobs_num = ( + obj.top_logprobs_num + if not_use_index + else obj.top_logprobs_num[index] + ) else: # A prefill request to cache the common prompt for parallel sampling + assert self.is_generation if obj.text is not None: if isinstance(obj.text, list): input_text = obj.text[index] @@ -213,19 +229,28 @@ class TokenizerManager: logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] - tokenized_obj = TokenizedGenerateReqInput( - rid, - input_text, - input_ids, - pixel_values, - image_hash, - image_size, - sampling_params, - return_logprob, - logprob_start_len, - top_logprobs_num, - obj.stream, - ) + if self.is_generation: + tokenized_obj = TokenizedGenerateReqInput( + rid, + input_text, + input_ids, + pixel_values, + image_hash, + image_size, + sampling_params, + return_logprob, + logprob_start_len, + top_logprobs_num, + obj.stream, + ) + else: # is embedding + tokenized_obj = TokenizedEmbeddingReqInput( + rid, + input_text, + input_ids, + sampling_params, + ) + self.send_to_router.send_pyobj(tokenized_obj) event = asyncio.Event() @@ -368,7 +393,7 @@ class TokenizerManager: self, event: asyncio.Event, state: ReqState, - obj: GenerateReqInput, + obj: Union[GenerateReqInput, EmbeddingReqInput], rid: str, request, ): @@ -381,12 +406,15 @@ class TokenizerManager: raise ValueError(f"Abort request {rid}") continue - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, - ) + if self.is_generation: + out = self.convert_logprob_style( + state.out_list[-1], + obj.return_logprob, + obj.top_logprobs_num, + obj.return_text_in_logprobs, + ) + else: # isinstance(obj, EmbeddingReqInput) + out = state.out_list[-1] # Log requests if self.server_args.log_requests and state.finished: @@ -459,8 +487,10 @@ class TokenizerManager: async def handle_loop(self): while True: - recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj() - assert isinstance(recv_obj, BatchStrOut) + recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = ( + await self.recv_from_detokenizer.recv_pyobj() + ) + assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut)) for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) @@ -468,10 +498,17 @@ class TokenizerManager: continue recv_obj.meta_info[i]["id"] = rid - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], - } + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": recv_obj.meta_info[i], + } + else: + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": recv_obj.meta_info[i], + } state.out_list.append(out_dict) state.finished = recv_obj.finished_reason[i] is not None state.event.set() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index e47c9e955..77941c8af 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -20,7 +20,7 @@ import multiprocessing import pickle import time import warnings -from typing import List, Optional +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import ( AbortReq, + BatchEmbeddingOut, BatchTokenIDOut, FlushCacheReq, + TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder @@ -205,7 +207,9 @@ class ModelTpServer: try: # Recv requests for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): + if isinstance( + recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ): self.handle_generate_request(recv_req) elif isinstance(recv_req, FlushCacheReq): self.flush_cache() @@ -297,41 +301,42 @@ class ModelTpServer: def handle_generate_request( self, - recv_req: TokenizedGenerateReqInput, + recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], ): req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) - req.pixel_values = recv_req.pixel_values - if req.pixel_values is not None: - req.pad_value = [ - (recv_req.image_hash) % self.model_config.vocab_size, - (recv_req.image_hash >> 16) % self.model_config.vocab_size, - (recv_req.image_hash >> 32) % self.model_config.vocab_size, - (recv_req.image_hash >> 64) % self.model_config.vocab_size, - ] - req.image_size = recv_req.image_size - ( - req.origin_input_ids, - req.image_offset, - ) = self.model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - req.pad_value, - req.pixel_values.shape, - req.image_size, - ) - req.sampling_params = recv_req.sampling_params - req.return_logprob = recv_req.return_logprob - req.logprob_start_len = recv_req.logprob_start_len - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream req.tokenizer = self.tokenizer - - # Init regex fsm - if req.sampling_params.regex is not None: - req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - req.sampling_params.regex + req.sampling_params = recv_req.sampling_params + if self.model_runner.is_generation: + req.pixel_values = recv_req.pixel_values + if req.pixel_values is not None: + req.pad_value = [ + (recv_req.image_hash) % self.model_config.vocab_size, + (recv_req.image_hash >> 16) % self.model_config.vocab_size, + (recv_req.image_hash >> 32) % self.model_config.vocab_size, + (recv_req.image_hash >> 64) % self.model_config.vocab_size, + ] + req.image_size = recv_req.image_size + ( + req.origin_input_ids, + req.image_offset, + ) = self.model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, + req.pad_value, + req.pixel_values.shape, + req.image_size, ) + req.return_logprob = recv_req.return_logprob + req.logprob_start_len = recv_req.logprob_start_len + req.top_logprobs_num = recv_req.top_logprobs_num + req.stream = recv_req.stream + + # Init regex fsm + if req.sampling_params.regex is not None: + req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + req.sampling_params.regex + ) # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -340,14 +345,17 @@ class ModelTpServer: "the max context length. Truncated!!!" ) req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] - req.sampling_params.max_new_tokens = min( - ( - req.sampling_params.max_new_tokens - if req.sampling_params.max_new_tokens is not None - else 1 << 30 - ), - self.max_req_input_len - 1 - len(req.origin_input_ids), - ) + + if self.model_runner.is_generation: + req.sampling_params.max_new_tokens = min( + ( + req.sampling_params.max_new_tokens + if req.sampling_params.max_new_tokens is not None + else 1 << 30 + ), + self.max_req_input_len - 1 - len(req.origin_input_ids), + ) + self.waiting_queue.append(req) def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: @@ -439,47 +447,68 @@ class ModelTpServer: self.model_config.vocab_size, self.int_token_logit_bias ) - # Forward and sample the next tokens - if batch.extend_num_tokens != 0: - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) + if self.model_runner.is_generation: + # Forward and sample the next tokens + if batch.extend_num_tokens != 0: + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = batch.sample(output.next_token_logits) - # Move logprobs to cpu - if output.next_token_logprobs is not None: - output.next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - output.input_token_logprobs = output.input_token_logprobs.tolist() - output.normalized_prompt_logprobs = ( - output.normalized_prompt_logprobs.tolist() - ) + # Move logprobs to cpu + if output.next_token_logprobs is not None: + output.next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, + ].tolist() + output.input_token_logprobs = output.input_token_logprobs.tolist() + output.normalized_prompt_logprobs = ( + output.normalized_prompt_logprobs.tolist() + ) - next_token_ids = next_token_ids.tolist() - else: - next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - - # Check finish conditions - pt = 0 - for i, req in enumerate(batch.reqs): - if req is not self.current_inflight_req: - # Inflight reqs' prefill is not finished - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_ids[i]) - req.check_finished() - - if req.finished(): - self.tree_cache.cache_finished_req(req) + next_token_ids = next_token_ids.tolist() else: - self.tree_cache.cache_unfinished_req(req) + next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - if req is self.current_inflight_req: - # Inflight request would get a new req idx - self.req_to_token_pool.free(req.req_pool_idx) + # Check finish conditions + pt = 0 + for i, req in enumerate(batch.reqs): + if req is not self.current_inflight_req: + # Inflight reqs' prefill is not finished + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_ids[i]) + req.check_finished() - if req.return_logprob: - self.add_logprob_return_values(i, req, pt, next_token_ids, output) - pt += req.extend_input_len + if req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) + + if req is self.current_inflight_req: + # Inflight request would get a new req idx + self.req_to_token_pool.free(req.req_pool_idx) + + if req.return_logprob: + self.add_logprob_return_values(i, req, pt, next_token_ids, output) + pt += req.extend_input_len + else: + assert batch.extend_num_tokens != 0 + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = output.embeddings.tolist() + + # Check finish conditions + for i, req in enumerate(batch.reqs): + req.embedding = embeddings[i] + if req is not self.current_inflight_req: + # Inflight reqs' prefill is not finished + req.check_finished() + + if req.finished(): + self.tree_cache.cache_finished_req(req) + else: + self.tree_cache.cache_unfinished_req(req) + + if req is self.current_inflight_req: + # Inflight request would get a new req idx + self.req_to_token_pool.free(req.req_pool_idx) self.handle_finished_requests(batch) @@ -596,15 +625,19 @@ class ModelTpServer: def handle_finished_requests(self, batch: ScheduleBatch): output_rids = [] - output_vids = [] - decoded_texts = [] - output_read_ids = [] - output_read_offsets = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] output_meta_info = [] output_finished_reason: List[BaseFinishReason] = [] + if self.model_runner.is_generation: + output_vids = [] + decoded_texts = [] + output_read_ids = [] + output_read_offsets = [] + output_skip_special_tokens = [] + output_spaces_between_special_tokens = [] + else: # for embedding model + output_embeddings = [] unfinished_indices = [] + for i, req in enumerate(batch.reqs): if not req.finished() and req is not self.current_inflight_req: unfinished_indices.append(i) @@ -619,56 +652,73 @@ class ModelTpServer: ) ): output_rids.append(req.rid) - output_vids.append(req.vid) - decoded_texts.append(req.decoded_text) - read_ids, read_offset = req.init_incremental_detokenize() - output_read_ids.append(read_ids) - output_read_offsets.append(read_offset) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens - ) - - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": str(req.finished_reason), - } - if req.return_logprob: - ( - meta_info["input_token_logprobs"], - meta_info["output_token_logprobs"], - meta_info["input_top_logprobs"], - meta_info["output_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.input_token_logprobs, - req.output_token_logprobs, - req.input_top_logprobs, - req.output_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) output_finished_reason.append(req.finished_reason) + if self.model_runner.is_generation: + output_vids.append(req.vid) + decoded_texts.append(req.decoded_text) + read_ids, read_offset = req.init_incremental_detokenize() + output_read_ids.append(read_ids) + output_read_offsets.append(read_offset) + output_skip_special_tokens.append( + req.sampling_params.skip_special_tokens + ) + output_spaces_between_special_tokens.append( + req.sampling_params.spaces_between_special_tokens + ) + + meta_info = { + "prompt_tokens": len(req.origin_input_ids), + "completion_tokens": len(req.output_ids), + "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, + "finish_reason": str(req.finished_reason), + } + if req.return_logprob: + ( + meta_info["input_token_logprobs"], + meta_info["output_token_logprobs"], + meta_info["input_top_logprobs"], + meta_info["output_top_logprobs"], + meta_info["normalized_prompt_logprob"], + ) = ( + req.input_token_logprobs, + req.output_token_logprobs, + req.input_top_logprobs, + req.output_top_logprobs, + req.normalized_prompt_logprob, + ) + output_meta_info.append(meta_info) + else: # for embedding model + output_embeddings.append(req.embedding) + meta_info = { + "prompt_tokens": len(req.origin_input_ids), + } + output_meta_info.append(meta_info) # Send to detokenizer if output_rids: - self.out_pyobjs.append( - BatchTokenIDOut( - output_rids, - output_vids, - decoded_texts, - output_read_ids, - output_read_offsets, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, + if self.model_runner.is_generation: + self.out_pyobjs.append( + BatchTokenIDOut( + output_rids, + output_vids, + decoded_texts, + output_read_ids, + output_read_offsets, + output_skip_special_tokens, + output_spaces_between_special_tokens, + output_meta_info, + output_finished_reason, + ) + ) + else: # for embedding model + self.out_pyobjs.append( + BatchEmbeddingOut( + output_rids, + output_embeddings, + output_meta_info, + output_finished_reason, + ) ) - ) # Remove finished reqs: update batch tensors batch.filter_batch(unfinished_indices) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 17ce5edf7..574ad3658 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, + is_generation_model, is_llama3_405b_fp8, is_multimodal_model, monkey_patch_vllm_dummy_weight_loader, @@ -132,8 +133,10 @@ class ModelRunner: self.init_cublas() self.init_flashinfer() - # Capture cuda graphs - self.init_cuda_graphs() + if self.is_generation: + # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models + # Capture cuda graphs + self.init_cuda_graphs() def load_model(self): logger.info( @@ -184,6 +187,10 @@ class ModelRunner: scheduler_config=None, cache_config=None, ) + self.is_generation = is_generation_model( + self.model_config.hf_config.architectures + ) + logger.info( f"[gpu={self.gpu_id}] Load weight end. " f"type={type(self.model).__name__}, " @@ -406,8 +413,10 @@ def import_model_classes(): entry, list ): # To support multiple model classes in one module for tmp in entry: + assert tmp.__name__ not in model_arch_name_to_cls model_arch_name_to_cls[tmp.__name__] = tmp else: + assert entry.__name__ not in model_arch_name_to_cls model_arch_name_to_cls[entry.__name__] = entry # compat: some models such as chatglm has incorrect class set in config.json @@ -417,6 +426,7 @@ def import_model_classes(): ): for remap in module.EntryClassRemapping: if isinstance(remap, tuple) and len(remap) == 2: + assert remap[0] not in model_arch_name_to_cls model_arch_name_to_cls[remap[0]] = remap[1] return model_arch_name_to_cls diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index b849a4b51..e8e678047 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module): EntryClass = LlamaEmbeddingModel +# compat: e5-mistral model.config class == MistralModel +EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 0443e9f2a..ee84a99e4 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import ( start_controller_process as start_controller_process_single, ) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process -from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, @@ -97,6 +97,7 @@ async def health() -> Response: async def get_model_info(): result = { "model_path": tokenizer_manager.model_path, + "is_generation": tokenizer_manager.is_generation, } return result @@ -148,6 +149,21 @@ app.post("/generate")(generate_request) app.put("/generate")(generate_request) +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +app.post("/encode")(encode_request) +app.put("/encode")(encode_request) + + @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): return await v1_completions(tokenizer_manager, raw_request) @@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer): except (AssertionError, requests.exceptions.RequestException) as e: last_traceback = get_exception_traceback() pass + model_info = res.json() if not success: if pipe_finish_writer is not None: @@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer): sys.exit(1) # Send a warmup request + request_name = "/generate" if model_info["is_generation"] else "/encode" + max_new_tokens = 8 if model_info["is_generation"] else 0 try: for _ in range(server_args.dp_size): res = requests.post( - url + "/generate", + url + request_name, json={ "text": "The capital city of France is", "sampling_params": { "temperature": 0, - "max_new_tokens": 8, + "max_new_tokens": max_new_tokens, }, }, headers=headers, @@ -529,5 +548,18 @@ class Runtime: ) return json.dumps(response.json()) + def encode( + self, + prompt: str, + ): + json_data = { + "text": prompt, + } + response = requests.post( + self.url + "/encode", + json=json_data, + ) + return json.dumps(response.json()) + def __del__(self): self.shutdown() diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e15cb6751..525ae8ca7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -223,6 +223,15 @@ def is_multimodal_model(model): raise ValueError("unrecognized type") +def is_generation_model(model_architectures): + if ( + "LlamaEmbeddingModel" in model_architectures + or "MistralModel" in model_architectures + ): + return False + return True + + def decode_video_base64(video_base64): from PIL import Image diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 3a8cff213..87277ca69 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer from sglang.srt.server import Runtime +from sglang.srt.utils import is_generation_model DEFAULT_PROMPTS = [ "The capital of France is", @@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [ NUM_TOP_LOGPROBS = 5 -def is_embedding_model(model_path): - # FIXME incomplete list - if "e5-mistral-7b-instruct" in model_path.lower(): - return True - return False - - def get_dtype_str(torch_dtype): if torch_dtype is torch.float16: return "float16" @@ -60,7 +54,7 @@ class HFRunner: self, model_path, torch_dtype=torch.float16, - is_embedding_model=None, + is_generation_model=None, ): self.in_queue = multiprocessing.Queue() self.out_queue = multiprocessing.Queue() @@ -72,13 +66,13 @@ class HFRunner: self.out_queue, model_path, torch_dtype, - is_embedding_model, + is_generation_model, ), ) self.model_proc.start() def start_model_process( - self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model + self, in_queue, out_queue, model_path, torch_dtype, is_generation_model ): self.tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -86,12 +80,12 @@ class HFRunner: trust_remote_code=True, ) - self.is_embedding_model = ( - is_embedding_model(model_path) - if is_embedding_model is None - else is_embedding_model + self.is_generation_model = ( + is_generation_model(model_path) + if is_generation_model is None + else is_generation_model ) - if not self.is_embedding_model: + if self.is_generation_model: self.model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, @@ -103,13 +97,13 @@ class HFRunner: self.model = SentenceTransformer( model_path, - device="cpu", - ).to(dtype=torch_dtype) + model_kwargs={"torch_dtype": torch_dtype}, + ) while True: prompts, max_new_tokens = in_queue.get() if prompts is not None: - if not self.is_embedding_model: + if self.is_generation_model: output_strs = [] prefill_logprobs = [] for p in prompts: @@ -144,7 +138,6 @@ class HFRunner: ) else: - assert isinstance(prompts, List[str]) logits = self.model.encode(prompts).tolist() out_queue.put(ModelOutput(embed_logits=logits)) @@ -175,16 +168,13 @@ class SRTRunner: model_path, tp_size=1, torch_dtype=torch.float16, - is_embedding_model=None, + is_generation_model=None, ): - self.is_embedding_model = ( - is_embedding_model(model_path) - if is_embedding_model is None - else is_embedding_model + self.is_generation_model = ( + is_generation_model(model_path) + if is_generation_model is None + else is_generation_model ) - if self.is_embedding_model: - raise NotImplementedError() - self.runtime = Runtime( model_path=model_path, tp_size=tp_size, @@ -196,38 +186,45 @@ class SRTRunner: prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=64, ): - # the return value contains logprobs from prefill - output_strs = [] - top_input_logprobs = [] - sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - for prompt in prompts: - response = self.runtime.generate( - prompt, - sampling_params=sampling_params, - return_logprob=True, - top_logprobs_num=NUM_TOP_LOGPROBS, - ) - response = json.loads(response) - output_strs.append(response["text"]) - top_input_logprobs.append( - [ - [tup[0] for tup in x[:NUM_TOP_LOGPROBS]] - for x in response["meta_info"]["input_top_logprobs"][1:] - ] - + [ + if self.is_generation_model: + # the return value contains logprobs from prefill + output_strs = [] + top_input_logprobs = [] + sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} + for prompt in prompts: + response = self.runtime.generate( + prompt, + sampling_params=sampling_params, + return_logprob=True, + top_logprobs_num=NUM_TOP_LOGPROBS, + ) + response = json.loads(response) + output_strs.append(response["text"]) + top_input_logprobs.append( [ - tup[0] - for tup in response["meta_info"]["output_top_logprobs"][0][ - :NUM_TOP_LOGPROBS + [tup[0] for tup in x[:NUM_TOP_LOGPROBS]] + for x in response["meta_info"]["input_top_logprobs"][1:] + ] + + [ + [ + tup[0] + for tup in response["meta_info"]["output_top_logprobs"][0][ + :NUM_TOP_LOGPROBS + ] ] ] - ] - ) - # print(response["meta_info"]["output_top_logprobs"][0]) + ) - return ModelOutput( - output_strs=output_strs, top_input_logprobs=top_input_logprobs - ) + return ModelOutput( + output_strs=output_strs, top_input_logprobs=top_input_logprobs + ) + else: + logits = [] + for prompt in prompts: + response = self.runtime.encode(prompt) + response = json.loads(response) + logits.append(response["embedding"]) + return ModelOutput(embed_logits=logits) def __enter__(self): return self diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index c6212dc39..613645b57 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -12,6 +12,8 @@ from typing import Callable, List, Optional import numpy as np import requests +import torch +import torch.nn.functional as F from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI @@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float): print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") return 0 if success else -1 + + +def get_similarities(vec1, vec2): + return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0) diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py new file mode 100644 index 000000000..c29c33188 --- /dev/null +++ b/test/srt/models/test_embedding_models.py @@ -0,0 +1,69 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +import torch + +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner +from sglang.test.test_utils import get_similarities + +MODELS = [("intfloat/e5-mistral-7b-instruct", 1)] +TORCH_DTYPES = [torch.float16] + + +class TestEmbeddingModels(unittest.TestCase): + + def assert_close_prefill_logits( + self, + prompts, + model_path, + tp_size, + torch_dtype, + ) -> None: + with HFRunner( + model_path, torch_dtype=torch_dtype, is_generation_model=False + ) as hf_runner: + hf_outputs = hf_runner.forward(prompts) + + with SRTRunner( + model_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation_model=False, + ) as srt_runner: + srt_outputs = srt_runner.forward(prompts) + + for i in range(len(prompts)): + hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) + srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) + + similarities = torch.tensor(get_similarities(hf_logits, srt_logits)) + + tolerance = 1e-2 + assert torch.all( + abs(similarities - 1) < tolerance + ), f"embeddings not all close" + + def test_prefill_logits(self): + for model, tp_size in MODELS: + for torch_dtype in TORCH_DTYPES: + self.assert_close_prefill_logits( + DEFAULT_PROMPTS, model, tp_size, torch_dtype + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/srt/models/test_causal_models.py b/test/srt/models/test_generation_models.py similarity index 94% rename from test/srt/models/test_causal_models.py rename to test/srt/models/test_generation_models.py index 4aeaadb99..f05764802 100644 --- a/test/srt/models/test_causal_models.py +++ b/test/srt/models/test_generation_models.py @@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase): torch_dtype, ) -> None: with HFRunner( - model_path, torch_dtype=torch_dtype, is_embedding_model=False + model_path, torch_dtype=torch_dtype, is_generation_model=True ) as hf_runner: hf_outputs = hf_runner.forward(prompts) @@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase): model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_embedding_model=False, + is_generation_model=True, ) as srt_runner: srt_outputs = srt_runner.forward(prompts) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index edb8db316..67d772b30 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -10,7 +10,8 @@ suites = { "test_vision_openai_server.py", "test_chunked_prefill.py", "test_torch_compile.py", - "models/test_causal_models.py", + "models/test_generation_models.py", + "models/test_embedding_models.py", "sampling/penaltylib", ], "sampling/penaltylib": glob.glob(