diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 7a8751043..b1752f89e 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [ def get_context_length(config): - """Get the context length of a model from a huggingface model configs. - And here the config should be text_config part if the model is a multimodal - LLM. - """ - text_config = getattr(config, "text_config", config) + """Get the context length of a model from a huggingface model configs.""" + text_config = config rope_scaling = getattr(text_config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = rope_scaling.get("factor", 1) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 339638c0a..9c5ed14f3 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -238,7 +238,7 @@ class EmbeddingReqInput: self.rid = uuid.uuid4().hex if self.sampling_params is None: self.sampling_params = {} - self.sampling_params["max_new_tokens"] = 1 + self.sampling_params["max_new_tokens"] = 0 else: if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)] @@ -248,7 +248,7 @@ class EmbeddingReqInput: if self.sampling_params is None: self.sampling_params = [{}] * self.batch_size for i in range(self.batch_size): - self.sampling_params[i]["max_new_tokens"] = 1 + self.sampling_params[i]["max_new_tokens"] = 0 def regenerate_rid(self): self.rid = uuid.uuid4().hex diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index a9eaa81c1..a6f460846 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -34,6 +34,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.torchao_utils import apply_torchao_config_ @@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module): self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() def forward( @@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + get_embedding: bool = False, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch - ) + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, forward_batch + ) + else: + return self.pooler(hidden_states, forward_batch) def get_hidden_dim(self, module_name): # return input_dim, output_dim diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index 19e324f92..da43d03fc 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.pooler(hidden_states, forward_batch) - def load_weights( - self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module): ] params_dict = dict(self.model.named_parameters()) - def load_weights_per_param(name, loaded_weight): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: return if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: @@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if name is None or loaded_weight is None: - for name, loaded_weight in weights: - load_weights_per_param(name, loaded_weight) - else: - load_weights_per_param(name, loaded_weight) - class MistralModel(LlamaEmbeddingModel): pass diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 2e9c0457f..5b68d1d32 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + get_embedding: bool = True, ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaForSequenceClassification is only used for embedding" + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) scores = self.score(hidden_states) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 59adb2ee7..82412a51a 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() for i, image in enumerate(forward_batch.image_inputs): - if image == None: + if image is None: continue start_idx = extend_start_loc_cpu[i] prefix_len = prefix_lens_cpu[i] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 854a0c658..d96176b2d 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -254,7 +254,7 @@ app.put("/encode")(encode_request) async def judge_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request.""" + """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" try: ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret @@ -696,24 +696,8 @@ class Runtime: self, prompt: Union[str, List[str], List[Dict], List[List[Dict]]], ): - if isinstance(prompt, str) or isinstance(prompt[0], str): - # embedding - json_data = { - "text": prompt, - } - response = requests.post( - self.url + "/encode", - json=json_data, - ) - else: - # reward - json_data = { - "conv": prompt, - } - response = requests.post( - self.url + "/judge", - json=json_data, - ) + json_data = {"text": prompt} + response = requests.post(self.url + "/encode", json=json_data) return json.dumps(response.json()) def __del__(self): diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 217065bd2..3870c4503 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -273,6 +273,7 @@ class SRTRunner: disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, ) + self.tokenizer = get_tokenizer(model_path) def forward( self, @@ -366,7 +367,7 @@ class SRTRunner: return ModelOutput(embed_logits=logits) else: scores = [x["embedding"][0] for x in response] - return ModelOutput(scores=logits) + return ModelOutput(scores=scores) def __enter__(self): return self diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index 3ad187cbb..aefe4f3e7 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16] class TestEmbeddingModels(unittest.TestCase): + @classmethod + def setUpClass(cls): + mp.set_start_method("spawn", force=True) + def assert_close_prefill_logits( self, prompts, @@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase): if __name__ == "__main__": - try: - mp.set_start_method("spawn") - except RuntimeError: - pass - unittest.main() diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 1d32b8af1..b4c2cde2d 100755 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16] class TestGenerationModels(unittest.TestCase): + @classmethod def setUpClass(cls): - mp.set_start_method("spawn") + mp.set_start_method("spawn", force=True) def assert_close_logits_and_output_strs( self, diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py index cd15b4967..499f7822a 100644 --- a/test/srt/models/test_reward_models.py +++ b/test/srt/models/test_reward_models.py @@ -18,10 +18,10 @@ import unittest import torch -from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner +from sglang.test.runners import HFRunner, SRTRunner MODELS = [ - ("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2), + ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2), ] TORCH_DTYPES = [torch.float16] @@ -43,6 +43,10 @@ CONVS = [ class TestRewardModels(unittest.TestCase): + @classmethod + def setUpClass(cls): + mp.set_start_method("spawn", force=True) + def assert_close_reward_scores( self, convs, @@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase): torch_dtype=torch_dtype, model_type="reward", ) as srt_runner: - srt_outputs = srt_runner.forward(convs) + prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False) + srt_outputs = srt_runner.forward(prompts) hf_scores = torch.tensor(hf_outputs.scores) srt_scores = torch.tensor(srt_outputs.scores) - print(hf_scores) - print(srt_scores) + print(f"{hf_scores=}") + print(f"{srt_scores=}") assert torch.all( abs(hf_scores - srt_scores) < tolerance @@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase): if __name__ == "__main__": - try: - mp.set_start_method("spawn") - except RuntimeError: - pass - unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ffc8e84f2..f7277f03d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,7 +8,7 @@ suites = { "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", - # "models/test_reward_models.py", + "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py", "test_double_sparsity.py",