Improve gemma and documentations (#278)
This commit is contained in:
@@ -369,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
||||
- Mistral
|
||||
- Mixtral
|
||||
- Qwen / Qwen 2
|
||||
- Gemma
|
||||
- Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
|
||||
- `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
|
||||
- LLaVA
|
||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
|
||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
|
||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
|
||||
- Yi-VL
|
||||
- see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
|
||||
- AWQ/GPTQ quantization
|
||||
|
||||
@@ -21,7 +21,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.verify = verify
|
||||
|
||||
res = http_request(
|
||||
self.base_url + "/get_model_info", auth_token=self.auth_token, verify=self.verify
|
||||
self.base_url + "/get_model_info",
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
self.model_info = res.json()
|
||||
@@ -41,7 +43,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -50,7 +52,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/generate",
|
||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -58,7 +60,10 @@ class RuntimeEndpoint(BaseBackend):
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -90,7 +95,10 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self._add_images(s, data)
|
||||
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify,
|
||||
)
|
||||
obj = res.json()
|
||||
comp = obj["text"]
|
||||
@@ -129,7 +137,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
json=data,
|
||||
stream=True,
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify
|
||||
verify=self.verify,
|
||||
)
|
||||
pos = 0
|
||||
|
||||
@@ -161,7 +169,10 @@ class RuntimeEndpoint(BaseBackend):
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
||||
@@ -175,7 +186,10 @@ class RuntimeEndpoint(BaseBackend):
|
||||
}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
obj = res.json()
|
||||
@@ -192,7 +206,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url + "/concate_and_append_request",
|
||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||
auth_token=self.auth_token,
|
||||
verify=self.verify
|
||||
verify=self.verify,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
from sglang.srt.managers.router.model_runner import global_server_args
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
if global_server_args.attention_reduce_in_fp32:
|
||||
REDUCE_TRITON_TYPE = tl.float32
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from torch import nn
|
||||
from transformers import GemmaConfig
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
@@ -136,7 +136,7 @@ class GemmaAttention(nn.Module):
|
||||
class GemmaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
@@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -213,12 +213,12 @@ class GemmaModel(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if not skip_embed:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
hidden_states = input_embeds
|
||||
|
||||
# Normalize the embedding by sqrt(hidden_size)
|
||||
hidden_states *= self.config.hidden_size**0.5
|
||||
@@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
@@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||
)
|
||||
|
||||
@@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata
|
||||
)
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
|
||||
@@ -550,6 +550,7 @@ class Runtime:
|
||||
tp_size: int = 1,
|
||||
model_mode: List[str] = (),
|
||||
schedule_heuristic: str = "lpm",
|
||||
attention_reduce_in_fp32: bool = False,
|
||||
random_seed: int = 42,
|
||||
log_level: str = "error",
|
||||
port: Optional[int] = None,
|
||||
@@ -572,6 +573,7 @@ class Runtime:
|
||||
tp_size=tp_size,
|
||||
model_mode=model_mode,
|
||||
schedule_heuristic=schedule_heuristic,
|
||||
attention_reduce_in_fp32=attention_reduce_in_fp32,
|
||||
random_seed=random_seed,
|
||||
log_level=log_level,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ class ServerArgs:
|
||||
model_mode: List[str] = ()
|
||||
schedule_heuristic: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
attention_reduce_in_fp32: bool = False
|
||||
random_seed: int = 42
|
||||
stream_interval: int = 8
|
||||
disable_log_stats: bool = False
|
||||
@@ -28,7 +29,6 @@ class ServerArgs:
|
||||
log_level: str = "info"
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
attention_reduce_in_fp32: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
@@ -157,6 +157,11 @@ class ServerArgs:
|
||||
default=ServerArgs.random_seed,
|
||||
help="Random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream-interval",
|
||||
type=int,
|
||||
@@ -190,11 +195,6 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
@@ -97,7 +97,9 @@ def http_request(url, json=None, stream=False, auth_token=None, verify=None):
|
||||
"Content-Type": "application/json",
|
||||
"Authentication": f"Bearer {auth_token}",
|
||||
}
|
||||
return requests.post(url, json=json, stream=True, headers=headers, verify=verify)
|
||||
return requests.post(
|
||||
url, json=json, stream=True, headers=headers, verify=verify
|
||||
)
|
||||
else:
|
||||
req = urllib.request.Request(url)
|
||||
req.add_header("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
@@ -66,9 +66,9 @@ class BenchBatch:
|
||||
p_idx = prefix_req_idx[i // fork_num].item()
|
||||
n_idx = self.req_pool_indices[i].item()
|
||||
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
||||
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
|
||||
self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
||||
)
|
||||
req_to_token[
|
||||
n_idx, prefix_len : prefix_len + extend_len
|
||||
] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
||||
|
||||
def update_decode(self, predict_ids, batch_size):
|
||||
assert predict_ids.shape[0] == batch_size
|
||||
@@ -81,9 +81,9 @@ class BenchBatch:
|
||||
self.out_cache_cont_start,
|
||||
self.out_cache_cont_end,
|
||||
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
||||
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||
self.out_cache_loc
|
||||
)
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens
|
||||
] = self.out_cache_loc
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user