diff --git a/README.md b/README.md index a89cbe232..b388df8d9 100644 --- a/README.md +++ b/README.md @@ -369,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port - Mistral - Mixtral - Qwen / Qwen 2 +- Gemma + - Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors. + - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32` - LLaVA - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` + - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` + - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000` - Yi-VL - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). - AWQ/GPTQ quantization diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 5c39a61b1..eb9dc3264 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -21,7 +21,9 @@ class RuntimeEndpoint(BaseBackend): self.verify = verify res = http_request( - self.base_url + "/get_model_info", auth_token=self.auth_token, verify=self.verify + self.base_url + "/get_model_info", + auth_token=self.auth_token, + verify=self.verify, ) assert res.status_code == 200 self.model_info = res.json() @@ -41,7 +43,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/generate", json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, auth_token=self.auth_token, - verify=self.verify + verify=self.verify, ) assert res.status_code == 200 @@ -50,7 +52,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/generate", json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, auth_token=self.auth_token, - verify=self.verify + verify=self.verify, ) assert res.status_code == 200 @@ -58,7 +60,10 @@ class RuntimeEndpoint(BaseBackend): data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} self._add_images(s, data) res = http_request( - self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify + self.base_url + "/generate", + json=data, + auth_token=self.auth_token, + verify=self.verify, ) assert res.status_code == 200 @@ -90,7 +95,10 @@ class RuntimeEndpoint(BaseBackend): self._add_images(s, data) res = http_request( - self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify + self.base_url + "/generate", + json=data, + auth_token=self.auth_token, + verify=self.verify, ) obj = res.json() comp = obj["text"] @@ -129,7 +137,7 @@ class RuntimeEndpoint(BaseBackend): json=data, stream=True, auth_token=self.auth_token, - verify=self.verify + verify=self.verify, ) pos = 0 @@ -161,7 +169,10 @@ class RuntimeEndpoint(BaseBackend): data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} self._add_images(s, data) res = http_request( - self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify + self.base_url + "/generate", + json=data, + auth_token=self.auth_token, + verify=self.verify, ) assert res.status_code == 200 prompt_len = res.json()["meta_info"]["prompt_tokens"] @@ -175,7 +186,10 @@ class RuntimeEndpoint(BaseBackend): } self._add_images(s, data) res = http_request( - self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify + self.base_url + "/generate", + json=data, + auth_token=self.auth_token, + verify=self.verify, ) assert res.status_code == 200 obj = res.json() @@ -192,7 +206,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/concate_and_append_request", json={"src_rids": src_rids, "dst_rid": dst_rid}, auth_token=self.auth_token, - verify=self.verify + verify=self.verify, ) assert res.status_code == 200 diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index c8a80fd32..fc316b6e7 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -4,8 +4,8 @@ import torch import triton import triton.language as tl -from sglang.srt.utils import wrap_kernel_launcher from sglang.srt.managers.router.model_runner import global_server_args +from sglang.srt.utils import wrap_kernel_launcher if global_server_args.attention_reduce_in_fp32: REDUCE_TRITON_TYPE = tl.float32 diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index e2f4492dd..4030c5cd7 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -7,7 +7,7 @@ import torch from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from torch import nn -from transformers import GemmaConfig +from transformers import PretrainedConfig from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import GeluAndMul @@ -136,7 +136,7 @@ class GemmaAttention(nn.Module): class GemmaDecoderLayer(nn.Module): def __init__( self, - config: GemmaConfig, + config: PretrainedConfig, layer_id: int = 0, linear_method: Optional[LinearMethodBase] = None, ) -> None: @@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module): class GemmaModel(nn.Module): def __init__( self, - config: GemmaConfig, + config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() @@ -213,12 +213,12 @@ class GemmaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - if not skip_embed: + if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: - hidden_states = input_ids + hidden_states = input_embeds # Normalize the embedding by sqrt(hidden_size) hidden_states *= self.config.hidden_size**0.5 @@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module): def __init__( self, - config: GemmaConfig, + config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, input_metadata: InputMetadata, - skip_embed: bool = False, + input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index b615f5953..8e42d48c7 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module): input_ids, positions, input_metadata, input_embeds=input_embeds ) elif input_metadata.forward_mode == ForwardMode.DECODE: - return self.language_model( - input_ids, positions, input_metadata - ) + return self.language_model(input_ids, positions, input_metadata) def load_weights( self, diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9dc6e9fd4..2e604c2c7 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -550,6 +550,7 @@ class Runtime: tp_size: int = 1, model_mode: List[str] = (), schedule_heuristic: str = "lpm", + attention_reduce_in_fp32: bool = False, random_seed: int = 42, log_level: str = "error", port: Optional[int] = None, @@ -572,6 +573,7 @@ class Runtime: tp_size=tp_size, model_mode=model_mode, schedule_heuristic=schedule_heuristic, + attention_reduce_in_fp32=attention_reduce_in_fp32, random_seed=random_seed, log_level=log_level, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 95ce19087..596b584ab 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -21,6 +21,7 @@ class ServerArgs: model_mode: List[str] = () schedule_heuristic: str = "lpm" schedule_conservativeness: float = 1.0 + attention_reduce_in_fp32: bool = False random_seed: int = 42 stream_interval: int = 8 disable_log_stats: bool = False @@ -28,7 +29,6 @@ class ServerArgs: log_level: str = "info" disable_regex_jump_forward: bool = False disable_disk_cache: bool = False - attention_reduce_in_fp32: bool = False def __post_init__(self): if self.tokenizer_path is None: @@ -157,6 +157,11 @@ class ServerArgs: default=ServerArgs.random_seed, help="Random seed.", ) + parser.add_argument( + "--attention-reduce-in-fp32", + action="store_true", + help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.", + ) parser.add_argument( "--stream-interval", type=int, @@ -190,11 +195,6 @@ class ServerArgs: action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) - parser.add_argument( - "--attention-reduce-in-fp32", - action="store_true", - help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 4bcb0fb37..143d4b8f7 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -97,7 +97,9 @@ def http_request(url, json=None, stream=False, auth_token=None, verify=None): "Content-Type": "application/json", "Authentication": f"Bearer {auth_token}", } - return requests.post(url, json=json, stream=True, headers=headers, verify=verify) + return requests.post( + url, json=json, stream=True, headers=headers, verify=verify + ) else: req = urllib.request.Request(url) req.add_header("Content-Type", "application/json; charset=utf-8") diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py index cd7c94c27..3e3534709 100644 --- a/test/srt/model/bench_llama_low_api.py +++ b/test/srt/model/bench_llama_low_api.py @@ -66,9 +66,9 @@ class BenchBatch: p_idx = prefix_req_idx[i // fork_num].item() n_idx = self.req_pool_indices[i].item() req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len] - req_to_token[n_idx, prefix_len : prefix_len + extend_len] = ( - self.out_cache_loc[i * extend_len : (i + 1) * extend_len] - ) + req_to_token[ + n_idx, prefix_len : prefix_len + extend_len + ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len] def update_decode(self, predict_ids, batch_size): assert predict_ids.shape[0] == batch_size @@ -81,9 +81,9 @@ class BenchBatch: self.out_cache_cont_start, self.out_cache_cont_end, ) = self.token_to_kv_pool.alloc_contiguous(batch_size) - self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = ( - self.out_cache_loc - ) + self.req_to_token_pool.req_to_token[ + self.req_pool_indices, self.seq_lens + ] = self.out_cache_loc self.seq_lens.add_(1)