diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 8aeac1ae4..375e24cd4 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -45,6 +45,7 @@ in the GitHub search bar. | **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. | | **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for 1M-token context and even a 5.6B multimodal variant (Phi-4V). | | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | +| **Arcee AFM-4.5B** | `arcee-ai/AFM-4.5B-Base` | Arcee's foundational model series for real world reliability and edge deployments. | | **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. | | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | -| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | +| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | diff --git a/python/sglang/srt/models/arcee.py b/python/sglang/srt/models/arcee.py new file mode 100644 index 000000000..f9ebfe19a --- /dev/null +++ b/python/sglang/srt/models/arcee.py @@ -0,0 +1,532 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import LlamaConfig + +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers + +logger = logging.getLogger(__name__) + + +class ArceeMLP(nn.Module): + """ + MLP block for the Arcee model, using a ReLU-squared activation function. + This differs from the Llama SwiGLU activation. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + # Arcee uses a single up-projection, not a merged gate/up projection. + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + reduce_results=reduce_results, + ) + if hidden_act != "relu2": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Arcee model in SGLang only supports 'relu2'." + ) + # The activation function is relu(x)^2 + self.act_fn = get_act_fn("relu2") + + def forward(self, x, forward_batch=None): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class ArceeAttention(nn.Module): + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + bias: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(self.partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class ArceeDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + self.self_attn = ArceeAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + bias=attention_bias, + ) + self.mlp = ArceeMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ArceeModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pp_group = get_pp_group() + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), + ) + else: + self.embed_tokens = PPMissingLayer() + + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: ArceeDecoderLayer( + config=config, quant_config=quant_config, layer_id=idx, prefix=prefix + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix="model.layers", + ) + + if self.pp_group.is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) + self.layers_to_capture = [] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + aux_hidden_states = [] + for i in range(self.start_layer, self.end_layer): + if i in self.layers_to_capture: + aux_hidden_states.append(hidden_states + residual) + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling factor attribute!" + ) + + +class ArceeForCausalLM(nn.Module): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + # Note: gate_proj is removed compared to Llama + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP + ".q_proj": (".qkv_proj", 0), + ".k_proj": (".qkv_proj", 1), + ".v_proj": (".qkv_proj", 2), + } + + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.pp_group = get_pp_group() + self.config = config + self.quant_config = quant_config + self.model = self._init_model(config, quant_config, add_prefix("model", prefix)) + # Arcee does not tie word embeddings + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # Parameters that are stacked in a single tensor in this model + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + self.capture_aux_hidden_states = False + + def _init_model( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + return ArceeModel(config, quant_config=quant_config, prefix=prefix) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> LogitsProcessorOutput: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + + if self.pp_group.is_last_rank: + if not get_embedding: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + ) + else: + return self.pooler(hidden_states, forward_batch) + else: + return hidden_states + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + continue + + # Handle FP8 kv-scale remapping + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + is_stacked = False + for param_name, weight_name, shard_id in self.stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + is_stacked = True + break + + if not is_stacked: + if name in params_dict: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning(f"Parameter {name} not found in model.") + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + + +EntryClass = [ArceeForCausalLM]