diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f8b50ad5d..5784a0975 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -1,4 +1,4 @@ -name: Pull Request Test +name: PR Test on: push: diff --git a/README.md b/README.md index 0f1cf3838..edde172a2 100644 --- a/README.md +++ b/README.md @@ -205,7 +205,7 @@ It supports streaming, vision, and most features of the Chat/Completions/Models/ ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2 ``` -- Add `--dp 2` to enable multi-GPU data parallelism. It can also be used together with tensor parallelism. Data parallelism is better for throughput if there is enough memory. +- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2 ``` diff --git a/docs/en/custom_chat_template.md b/docs/en/custom_chat_template.md index 815c7e676..3760bbc6a 100644 --- a/docs/en/custom_chat_template.md +++ b/docs/en/custom_chat_template.md @@ -1,6 +1,9 @@ # Custom Chat Template in SGLang Runtime -By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. +**NOTE**: There are two chat template systems in SGLang project. This document is about setting a custom chat template for the OpenAI-compatible API server (defined at [conversation.py](../../python/sglang/srt/conversation.py)). It is NOT related to the chat template used in the SGLang language frontend (defined at [chat_template.py](../../python/sglang/lang/chat_template.py)). + +By default, the server uses the chat template specified in the model tokenizer from Hugging Face. +It should just work for most official models such as Llama-2/Llama-3. If needed, you can also override the chat template when launching the server: diff --git a/examples/frontend_language/quick_start/local_example_llava_next.py b/examples/frontend_language/quick_start/local_example_llava_next.py index 823dc7b0e..fc5a1d04c 100644 --- a/examples/frontend_language/quick_start/local_example_llava_next.py +++ b/examples/frontend_language/quick_start/local_example_llava_next.py @@ -2,13 +2,8 @@ Usage: python3 local_example_llava_next.py """ -from PIL import ImageFile - import sglang as sgl from sglang.lang.chat_template import get_chat_template -from sglang.srt.utils import load_image - -ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images @sgl.function diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 5012f646e..344b51d2d 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -4,7 +4,7 @@ from typing import List, Optional from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend -from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import ( @@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend): base_url: str, api_key: Optional[str] = None, verify: Optional[str] = None, + chat_template_name: Optional[str] = None, ): super().__init__() self.support_concate_and_append = True @@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend): self._assert_success(res) self.model_info = res.json() - self.chat_template = get_chat_template_by_model_path( - self.model_info["model_path"] - ) + if chat_template_name: + self.chat_template = get_chat_template(chat_template_name) + else: + self.chat_template = get_chat_template_by_model_path( + self.model_info["model_path"] + ) def get_model_name(self): return self.model_info["model_path"] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 644670a2b..6af820641 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -86,8 +86,8 @@ class TokenizerManager: self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") - self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") + self.send_to_controller = context.socket(zmq.PUSH) + self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}") # Read model args self.model_path = server_args.model_path @@ -271,7 +271,7 @@ class TokenizerManager: input_ids, sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) + self.send_to_controller.send_pyobj(tokenized_obj) # Recv results event = asyncio.Event() @@ -367,7 +367,7 @@ class TokenizerManager: input_ids, sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) + self.send_to_controller.send_pyobj(tokenized_obj) event = asyncio.Event() state = ReqState([], False, event) @@ -500,14 +500,14 @@ class TokenizerManager: def flush_cache(self): req = FlushCacheReq() - self.send_to_router.send_pyobj(req) + self.send_to_controller.send_pyobj(req) def abort_request(self, rid: str): if rid not in self.rid_to_state: return del self.rid_to_state[rid] req = AbortReq(rid) - self.send_to_router.send_pyobj(req) + self.send_to_controller.send_pyobj(req) async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None @@ -524,7 +524,7 @@ class TokenizerManager: # wait for the previous generation requests to finish while len(self.rid_to_state) > 0: await asyncio.sleep(0) - self.send_to_router.send_pyobj(obj) + self.send_to_controller.send_pyobj(obj) self.model_update_result = asyncio.Future() result = await self.model_update_result if result.success: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 26afe6600..09b3c7127 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -606,16 +606,6 @@ def import_model_classes(): assert entry.__name__ not in model_arch_name_to_cls model_arch_name_to_cls[entry.__name__] = entry - # compat: some models such as chatglm has incorrect class set in config.json - # usage: [ tuple("From_Entry_Class_Name": EntryClass), ] - if hasattr(module, "EntryClassRemapping") and isinstance( - module.EntryClassRemapping, list - ): - for remap in module.EntryClassRemapping: - if isinstance(remap, tuple) and len(remap) == 2: - assert remap[0] not in model_arch_name_to_cls - model_arch_name_to_cls[remap[0]] = remap[1] - return model_arch_name_to_cls diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 9eb04dc26..94b405f8e 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -402,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module): weight_loader(param, loaded_weight) -EntryClass = ChatGLMForCausalLM -# compat: glm model.config class == ChatGLMModel -EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)] +class ChatGLMModel(ChatGLMForCausalLM): + pass + + +EntryClass = [ChatGLMForCausalLM, ChatGLMModel] diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 4dcafed7c..9cddcb34f 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module): config, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, - efficient_weight_load=False, ) -> None: super().__init__() self.config = config @@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights( - self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module): ] params_dict = dict(self.named_parameters()) - def load_weights_per_param(name, loaded_weight): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: return if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: @@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module): if name.startswith("model.vision_tower") and name not in params_dict: return + name = name.replace("attn.attention", "self_attn") for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if name is None or loaded_weight is None: - for name, loaded_weight in weights: - name = name.replace("attn.attention", "self_attn") - load_weights_per_param(name, loaded_weight) - else: - name = name.replace("attn.attention", "self_attn") - load_weights_per_param(name, loaded_weight) - EntryClass = ExaoneForCausalLM diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama.py similarity index 91% rename from python/sglang/srt/models/llama2.py rename to python/sglang/srt/models/llama.py index 22751d9b6..43c7cd54a 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama.py @@ -295,7 +295,6 @@ class LlamaForCausalLM(nn.Module): config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, - efficient_weight_load=False, ) -> None: super().__init__() self.config = config @@ -305,6 +304,8 @@ class LlamaForCausalLM(nn.Module): self.logits_processor = LogitsProcessor(config) self.sampler = Sampler() + self.param_dict = dict(self.named_parameters()) + @torch.no_grad() def forward( self, @@ -320,30 +321,7 @@ class LlamaForCausalLM(nn.Module): sample_output = self.sampler(logits_output, input_metadata.sampling_info) return sample_output, logits_output - def get_module_name(self, name): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id, num_shard) - ("qkv_proj", "q_proj", "q", 3), - ("qkv_proj", "k_proj", "k", 3), - ("qkv_proj", "v_proj", "v", 3), - ("gate_up_proj", "gate_proj", 0, 2), - ("gate_up_proj", "up_proj", 1, 2), - ] - for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: - if weight_name in name: - return ( - name.replace(weight_name, param_name)[: -len(".weight")], - num_shard, - ) - return name[: -len(".weight")], 1 - - def get_num_params(self): - params_dict = dict(self.named_parameters()) - return len(params_dict) - - def load_weights( - self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -352,9 +330,9 @@ class LlamaForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - params_dict = dict(self.named_parameters()) + params_dict = self.param_dict - def load_weights_per_param(name, loaded_weight): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: return if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: @@ -383,11 +361,5 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if name is None or loaded_weight is None: - for name, loaded_weight in weights: - load_weights_per_param(name, loaded_weight) - else: - load_weights_per_param(name, loaded_weight) - EntryClass = LlamaForCausalLM diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 03ab5e802..db424ff18 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -16,17 +16,16 @@ limitations under the License. from typing import Iterable, Optional, Tuple import torch -import tqdm from torch import nn from transformers import LlamaConfig from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata -from sglang.srt.models.llama2 import LlamaModel +from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel class LlamaForClassification(nn.Module): @@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module): self.model = LlamaModel(config, quant_config=quant_config) self.classification_head = nn.Linear( - config.hidden_size, config.classification_out_size + config.hidden_size, config.classification_out_size, bias=False ) self.eos_token_id = config.eos_token_id + self.param_dict = dict(self.named_parameters()) + @torch.no_grad() def forward( self, @@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module): (input_metadata.batch_size, self.config.classification_out_size) ).to(input_ids.device) - return LogitsProcessorOutput( + logits_output = LogitsProcessorOutput( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, @@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module): output_top_logprobs=None, ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name or "projector" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "lm_head" in name: - continue + # A dummy to make this work + sample_output = SampleOutput( + success=torch.full( + size=(scores.shape[0],), + fill_value=True, + dtype=torch.bool, + ), + probs=torch.full( + size=(scores.shape[0], 1), + fill_value=1.0, + dtype=torch.float16, + ), + batch_next_token_ids=torch.full( + size=(scores.shape[0],), + fill_value=0, + dtype=torch.long, + ), + ) + return sample_output, logits_output - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = self.param_dict + + for name, loaded_weight in weights: + if "classification_head" in name: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + elif "lm_head" in name: + continue + else: + LlamaForCausalLM.load_weights(self, [(name, loaded_weight)]) EntryClass = LlamaForClassification diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index e4e9174f1..fe407b29f 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Tuple +from typing import Iterable, Tuple import torch from torch import nn @@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.model_executor.model_runner import InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel +from sglang.srt.models.llama import LlamaModel class LlamaEmbeddingModel(nn.Module): @@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module): config: LlamaConfig, quant_config=None, cache_config=None, - efficient_weight_load=False, ) -> None: super().__init__() self.model = LlamaModel(config, quant_config=quant_config) @@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module): load_weights_per_param(name, loaded_weight) -EntryClass = LlamaEmbeddingModel -# compat: e5-mistral model.config class == MistralModel -EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)] +class MistralModel(LlamaEmbeddingModel): + pass + + +EntryClass = [LlamaEmbeddingModel, MistralModel] diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 7dcf5348b..2e3c9ceba 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -41,7 +41,7 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM @@ -395,21 +395,19 @@ class LlavaBaseForCausalLM(nn.Module): "model.mm_projector.0": "multi_modal_projector.linear_1", "model.mm_projector.2": "multi_modal_projector.linear_2", "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + "model.image_newline": "language_model.model.image_newline", } params_dict = dict(self.named_parameters()) - weights = list(weights) for name, loaded_weight in weights: - # FIXME: why projector weights read two times? - if "projector" in name or "vision_tower" in name: + if "projector" in name or "vision_tower" in name or "image_newline" in name: for weight_name, param_name in projector_weights.items(): if weight_name in name: name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - # load language model - self.language_model.load_weights(weights) + else: + self.language_model.load_weights([(name, loaded_weight)]) @property def num_patches_per_side(self): @@ -429,6 +427,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): self.vision_tower = None self.config.vision_config.hidden_size = config.mm_hidden_size self.config.text_config.hidden_size = config.hidden_size + self.multi_modal_projector = LlavaMultiModalProjector(config) self.language_model = LlamaForCausalLM(config, quant_config=quant_config) if "unpad" in getattr(config, "mm_patch_merge_type", ""): @@ -448,9 +447,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM): self.config = config self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) - if getattr(self.config, "text_config", None) is None: self.config.text_config = Qwen2Config(self.config._name_or_path) @@ -459,7 +458,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM): if getattr(self.config, "projector_hidden_act", None) is None: self.config.projector_hidden_act = "gelu" - if getattr(self.config, "image_token_index", None) is None: self.config.image_token_index = 151646 @@ -482,9 +480,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): self.config = config self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) - if getattr(self.config, "text_config", None) is None: self.config.text_config = MistralConfig(self.config._name_or_path) @@ -493,7 +491,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): if getattr(self.config, "projector_hidden_act", None) is None: self.config.projector_hidden_act = "gelu" - if getattr(self.config, "image_token_index", None) is None: self.config.image_token_index = 32000 diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 44e400ff6..f268ecbbc 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM class LlavaVidForCausalLM(nn.Module): @@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module): "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1", "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2", "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + "model.image_newline": "language_model.model.image_newline", } params_dict = dict(self.named_parameters()) - weights = list(weights) for name, loaded_weight in weights: # FIXME: why projector weights read two times? - if "projector" in name or "vision_tower" in name: + if "projector" in name or "vision_tower" in name or "image_newline" in name: for weight_name, param_name in projector_weights.items(): if weight_name in name: name = name.replace(weight_name, param_name) @@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module): continue weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - # load language model - self.language_model.load_weights(weights) + else: + self.language_model.load_weights([(name, loaded_weight)]) @property def num_patches_per_side(self): diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index 614c1c1d7..1430ece43 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -15,12 +15,11 @@ limitations under the License. """Inference-only Mistral model.""" -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM class MistralForCausalLM(LlamaForCausalLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + pass EntryClass = MistralForCausalLM diff --git a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py index cafbd19fd..dbcafb88d 100644 --- a/scripts/deprecated/test_httpserver_classify.py +++ b/scripts/deprecated/test_httpserver_classify.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m sglang.launch_server --model-path /model/llama-classification +python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification python3 test_httpserver_classify.py """ diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index d2d311610..95aeddb9a 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -3,23 +3,24 @@ Usage: python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4 Reference output: +========== Prompt 0 ========== +prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141], + device='cuda:0') The capital of France is Paris. The capital of the United States is Washington, D.C. -The capital of Canada is Ottawa. -The capital of Japan is Tokyo -prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141], + +========== Prompt 1 ========== +prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742], device='cuda:0') The capital of the United Kindom is London. The capital of the United Kingdom is London. -The capital of the United Kingdom is London. -The capital of the United Kingdom is London. -prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742], +The capital of + +========== Prompt 2 ========== +prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609], device='cuda:0') Today is a sunny day and I like to go for a walk in the park. -I'm going to the park to play in the grass and water. -Today is a very -prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609], - device='cuda:0') +I'm going to the """ import argparse @@ -47,7 +48,7 @@ def normal_text(args): ] max_new_tokens = 16 - for p in prompts: + for i, p in enumerate(prompts): if isinstance(p, str): input_ids = t.encode(p, return_tensors="pt").cuda() else: @@ -60,7 +61,8 @@ def normal_text(args): prefill_logits = m.forward(input_ids).logits[0][-1] - print("prefill logits", prefill_logits) + print(f"\n========== Prompt {i} ==========") + print("prefill logits (final)", prefill_logits) print(output_str)