Support gpt-bigcode model class (#681)
This commit is contained in:
@@ -34,12 +34,11 @@ class LogitProcessorOutput:
|
||||
@dataclasses.dataclass
|
||||
class LogitsMetadata:
|
||||
forward_mode: ForwardMode
|
||||
extend_seq_lens: torch.Tensor
|
||||
extend_start_loc: torch.Tensor
|
||||
|
||||
# For logprobs
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: List[int]
|
||||
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
||||
from sglang.srt.managers.controller.infer_batch import (
|
||||
@@ -16,8 +17,28 @@ from sglang.srt.managers.controller.infer_batch import (
|
||||
)
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse=False):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
if reverse:
|
||||
sub._forward_method = sub.forward_cuda
|
||||
else:
|
||||
sub._forward_method = sub.forward_native
|
||||
if isinstance(sub, torch.nn.Module):
|
||||
_to_torch(sub, reverse)
|
||||
|
||||
|
||||
def get_forward(model: torch.nn.Module, use_torch: bool):
|
||||
if use_torch:
|
||||
_to_torch(model, reverse=False)
|
||||
return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
|
||||
else:
|
||||
_to_torch(model, reverse=True)
|
||||
return model.forward
|
||||
|
||||
|
||||
class CudaGraphRunner:
|
||||
def __init__(self, model_runner, max_batch_size_to_capture):
|
||||
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
||||
self.model_runner = model_runner
|
||||
self.graphs = {}
|
||||
self.input_buffers = {}
|
||||
@@ -55,6 +76,8 @@ class CudaGraphRunner:
|
||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
||||
|
||||
def can_run(self, batch_size):
|
||||
return batch_size < self.max_bs
|
||||
|
||||
@@ -63,18 +86,19 @@ class CudaGraphRunner:
|
||||
with graph_capture() as graph_capture_context:
|
||||
self.stream = graph_capture_context.stream
|
||||
for bs in batch_size_list:
|
||||
forward = get_forward(self.model_runner.model, bs in self.compile_bs)
|
||||
(
|
||||
graph,
|
||||
input_buffers,
|
||||
output_buffers,
|
||||
flashinfer_handler,
|
||||
) = self.capture_one_batch_size(bs)
|
||||
) = self.capture_one_batch_size(bs, forward)
|
||||
self.graphs[bs] = graph
|
||||
self.input_buffers[bs] = input_buffers
|
||||
self.output_buffers[bs] = output_buffers
|
||||
self.flashinfer_handlers[bs] = flashinfer_handler
|
||||
|
||||
def capture_one_batch_size(self, bs):
|
||||
def capture_one_batch_size(self, bs, forward):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
stream = self.stream
|
||||
|
||||
@@ -127,9 +151,8 @@ class CudaGraphRunner:
|
||||
skip_flashinfer_init=True,
|
||||
)
|
||||
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
|
||||
return self.model_runner.model.forward(
|
||||
input_ids, input_metadata.positions, input_metadata
|
||||
)
|
||||
|
||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||
|
||||
for _ in range(2):
|
||||
run_once()
|
||||
|
||||
@@ -244,7 +244,9 @@ class ModelRunner:
|
||||
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
|
||||
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
|
||||
self.cuda_graph_runner = CudaGraphRunner(
|
||||
self, max_batch_size_to_capture=max(batch_size_list)
|
||||
self,
|
||||
max_batch_size_to_capture=max(batch_size_list),
|
||||
use_torch_compile=self.server_args.enable_torch_compile,
|
||||
)
|
||||
try:
|
||||
self.cuda_graph_runner.capture(batch_size_list)
|
||||
|
||||
282
python/sglang/srt/models/gpt_bigcode.py
Normal file
282
python/sglang/srt/models/gpt_bigcode.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# Adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
|
||||
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTBigCodeConfig
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.infer_batch import InputMetadata
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
assert total_num_heads % self.tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.multi_query = config.multi_query
|
||||
if self.multi_query:
|
||||
total_num_kv_heads = 1
|
||||
self.num_kv_heads = 1
|
||||
else:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.kv_dim = self.head_dim * self.num_kv_heads
|
||||
self.c_attn = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
scaling=self.scale,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
[
|
||||
self.hidden_size // self.tensor_model_parallel_world_size,
|
||||
self.kv_dim,
|
||||
self.kv_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class GPTBigMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.act = get_act_fn(
|
||||
config.activation_function, quant_config, intermediate_size
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.c_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states, input_metadata=input_metadata
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.wte = VocabParallelEmbedding(
|
||||
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
|
||||
)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPTBigCodeBlock(i, config, cache_config, quant_config)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
for i in range(len(self.h)):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, input_metadata)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {"c_attn": ["c_attn"]}
|
||||
|
||||
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
|
||||
|
||||
embedding_modules = {
|
||||
"wte": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(
|
||||
config, cache_config, quant_config, lora_config
|
||||
)
|
||||
self.lm_head = self.transformer.wte
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
if ".attn.bias" in name:
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
|
||||
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
|
||||
weight_loader(param, loaded_weight, "q")
|
||||
weight_loader(param, loaded_weight, "k")
|
||||
weight_loader(param, loaded_weight, "v")
|
||||
else:
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = GPTBigCodeForCausalLM
|
||||
@@ -157,6 +157,19 @@ def _set_global_server_args(server_args: ServerArgs):
|
||||
}
|
||||
|
||||
|
||||
def _set_torch_compile_config():
|
||||
# The following configurations are for torch compile optimizations
|
||||
import torch._dynamo.config
|
||||
import torch._inductor.config
|
||||
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
||||
|
||||
# FIXME: tmp workaround
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 128
|
||||
|
||||
|
||||
def launch_server(
|
||||
server_args: ServerArgs,
|
||||
model_overide_args: Optional[dict] = None,
|
||||
@@ -190,6 +203,10 @@ def launch_server(
|
||||
if server_args.chat_template:
|
||||
# TODO: replace this with huggingface transformers template
|
||||
load_chat_template_for_openai_api(server_args.chat_template)
|
||||
|
||||
if server_args.enable_torch_compile:
|
||||
_set_torch_compile_config()
|
||||
|
||||
_set_global_server_args(server_args)
|
||||
|
||||
# Allocate ports
|
||||
|
||||
@@ -55,6 +55,7 @@ class ServerArgs:
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_cuda_graph: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
attention_reduce_in_fp32: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
efficient_weight_load: bool = False
|
||||
@@ -317,6 +318,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-torch-compile",
|
||||
action="store_true",
|
||||
help="Optimize the model with torch.compile, experimental feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user