From 5641a0945831bddefe56a99d9602da36d20064f8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 25 Apr 2025 15:50:28 -0700 Subject: [PATCH] Revert "[Model] Support `ArcticForCausalLM` architecture (Snowflake/snowflake-arctic-instruct)" (#5754) --- docs/supported_models/generative_models.md | 1 - python/sglang/srt/configs/__init__.py | 2 - python/sglang/srt/configs/arctic.py | 127 ----- python/sglang/srt/hf_transformers_utils.py | 2 - python/sglang/srt/models/arctic.py | 634 --------------------- 5 files changed, 766 deletions(-) delete mode 100644 python/sglang/srt/configs/arctic.py delete mode 100644 python/sglang/srt/models/arctic.py diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 5c9f47cb9..486839ae9 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -28,7 +28,6 @@ python3 -m sglang.launch_server \ | **Command-R** (Cohere) | `CohereForAI/c4ai-command-r-v01` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. | | **DBRX** (Databricks) | `databricks/dbrx-instruct` | Databricks’ 132B-parameter MoE model (36B active) trained on 12T tokens; competes with GPT-3.5 quality as a fully open foundation model. | | **Grok** (xAI) | `xai-org/grok-1` | xAI’s grok-1 model known for vast size(314B parameters) and high quality; integrated in SGLang for high-performance inference. | -| **Arctic** (Snowflake) | `Snowflake/snowflake-arctic-instruct` | Snowflake’s dense-MoE model (17B active, 480B total) with top-2 routing, built for enterprise-grade reasoning, code, and instruction tasks. | | **ChatGLM** (GLM-130B family) | `THUDM/chatglm2-6b` | Zhipu AI’s bilingual chat model (6B) excelling at Chinese-English dialogue; fine-tuned for conversational quality and alignment. | | **InternLM 2** (7B, 20B) | `internlm/internlm2-7b` | Next-gen InternLM (7B and 20B) from SenseTime, offering strong reasoning and ultra-long context support (up to 200K tokens). | | **ExaONE 3** (Korean-English) | `LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct` | LG AI Research’s Korean-English model (7.8B) trained on 8T tokens; provides high-quality bilingual understanding and generation. | diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 7e07fe3a6..1e8370ba7 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,4 +1,3 @@ -from sglang.srt.configs.arctic import ArcticConfig from sglang.srt.configs.chatglm import ChatGLMConfig from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config @@ -6,7 +5,6 @@ from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.janus_pro import MultiModalityConfig __all__ = [ - "ArcticConfig", "ExaoneConfig", "ChatGLMConfig", "DbrxConfig", diff --git a/python/sglang/srt/configs/arctic.py b/python/sglang/srt/configs/arctic.py deleted file mode 100644 index ff3887373..000000000 --- a/python/sglang/srt/configs/arctic.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -"""Arctic model configuration""" - -from typing import Any, Dict, Optional - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", -} - - -class ArcticConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an - Arctic model according to the specified arguments, defining the model architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`ArcticModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 14336): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 4096): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). - pad_token_id (`int`, *optional*): - The id of the padding token. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the "beginning-of-sequence" token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the "end-of-sequence" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - sliding_window (`int`, *optional*): - Sliding window attention window size. If not specified, will default to `4096`. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - num_experts_per_tok (`int`, *optional*, defaults to 2): - The number of experts to root per-token, can be also interpreted as the `top-p` routing parameter - num_local_experts (`int`, *optional*, defaults to 8): - Number of experts per Sparse MLP layer. - moe_layer_frequency (`int`, *optional*, defaults to 2): - Frequency of MoE layers in the model. - """ - - model_type = "arctic" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=4096, - initializer_range=0.02, - rms_norm_eps=1e-5, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - rope_theta=1e6, - sliding_window=None, - attention_dropout=0.0, - num_experts_per_tok=1, - num_local_experts=8, - moe_layer_frequency=2, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.num_experts_per_tok = num_experts_per_tok - self.num_local_experts = num_local_experts - self.moe_layer_frequency = moe_layer_frequency - - # For backward compatibility - self._attn_implementation = kwargs.pop("_attn_implementation", "eager") - self.use_residual = kwargs.pop("use_residual", True) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 9c6ecb3e3..0a189a7bf 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -31,7 +31,6 @@ from transformers import ( from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from sglang.srt.configs import ( - ArcticConfig, ChatGLMConfig, DbrxConfig, DeepseekVL2Config, @@ -42,7 +41,6 @@ from sglang.srt.connector import create_remote_connector from sglang.srt.utils import is_remote_url _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - ArcticConfig.model_type: ArcticConfig, ChatGLMConfig.model_type: ChatGLMConfig, DbrxConfig.model_type: DbrxConfig, ExaoneConfig.model_type: ExaoneConfig, diff --git a/python/sglang/srt/models/arctic.py b/python/sglang/srt/models/arctic.py deleted file mode 100644 index 3049b8a5c..000000000 --- a/python/sglang/srt/models/arctic.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# Adapted from -# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/arctic.py - -"""Inference-only Snowflake Arctic model.""" - -import logging -from typing import Iterable, List, Optional, Set, Tuple, Union - -import torch -from torch import nn - -from sglang.srt.configs.arctic import ArcticConfig -from sglang.srt.distributed import ( - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.fused_moe import fused_experts, fused_topk -from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.quantization import QuantizationConfig -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.utils import set_weight_attrs -from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.platforms import current_platform - -from .interfaces import SupportsPP, SupportsQuant -from .utils import ( - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) - -logger = logging.getLogger(__name__) - - -class ArcticMLP(nn.Module): - def __init__( - self, - config: ArcticConfig, - expert_id: int = -1, - is_residual_mlp: bool = False, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ): - super().__init__() - self.hidden_size = config.hidden_size - self.expert_id = expert_id - - self.ffn_dim = ( - config.intermediate_size if not is_residual_mlp else self.hidden_size - ) - - self.w13 = MergedColumnParallelLinear( - self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config - ) - self.w2 = RowParallelLinear( - self.ffn_dim, - self.hidden_size, - bias=False, - reduce_results=reduce_results, - quant_config=quant_config, - ) - if config.hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now." - ) - self.act_fn = SiluAndMul() - - def forward(self, hidden_states): - gate_up, _ = self.w13(hidden_states) - hidden_states = self.act_fn(gate_up) - hidden_states, _ = self.w2(hidden_states) - return hidden_states - - -class ArcticMoE(nn.Module): - """ - Model-parallel implementation of Arctic MoE Layer. - """ - - def __init__( - self, - config: ArcticConfig, - tp_size: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ): - super().__init__() - - layer_id = extract_layer_index(prefix) - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.hidden_size = config.hidden_size - self.num_experts = config.num_local_experts - self.layer_id = layer_id - self.top_k = config.num_experts_per_tok - self.intermediate_size = config.intermediate_size // self.tp_size - - self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 - self.is_quant = quant_config is not None - self.reduce_results = reduce_results - # Some other parameters - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - if not self.is_moe_layer: - self.mlp = ArcticMLP( - config, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.mlp", - ) - else: - self.gate = ReplicatedLinear( - self.hidden_size, - self.num_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=quant_config, - prefix=f"{prefix}.gate", - ) - if self.is_quant: - raise NotImplementedError("Quantization is not supported yet.") - else: - self.ws = nn.Parameter( - torch.empty( - self.num_experts, - 2 * self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype, - ) - ) - self.w2s = nn.Parameter( - torch.empty( - self.num_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype, - ) - ) - set_weight_attrs( - self.ws, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2s, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - ): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - - def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - do_normalize = self.top_k > 1 - topk_weights, topk_ids = fused_topk( - hidden_states, router_logits, self.top_k, renormalize=do_normalize - ) - # topk_ids: (num_tokens, k) - if self.is_quant: - raise NotImplementedError("Quantization is not supported yet.") - # if 2 * num_tokens <= self.num_experts: - # # If much fewer tokens than experts, use selective dequantize. - # ws_dequantized = self.ws.ds_selective_dequantize(topk_ids.flatten()) - # w2s_dequantized = self.w2s.ds_selective_dequantize(topk_ids.flatten()) - # # We gathered the experts to the tokens so update the mapping. - # topk_ids = torch.arange( - # 0, - # topk_ids.numel(), - # device=topk_ids.device, - # ).reshape(topk_ids.shape) - # else: - # ws_dequantized = self.ws.ds_dequantize() - # w2s_dequantized = self.w2s.ds_dequantize() - - final_hidden_states = fused_experts( - hidden_states, - self.ws, - self.w2s, - topk_weights, - topk_ids, - inplace=True, - ) - if self.reduce_results and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states.view(num_tokens, hidden_size) - - def forward(self, hidden_states: torch.Tensor): - if self.is_moe_layer: - final_hidden_states = self.local_moe_fused(hidden_states) - else: - final_hidden_states = self.mlp(hidden_states) - return final_hidden_states - - -class ArcticAttention(nn.Module): - def __init__( - self, - config: ArcticConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - layer_idx = extract_layer_index(prefix) - - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: - assert self.total_num_kv_heads % tp_size == 0 - else: - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = self.hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - reduce_results=True, - quant_config=quant_config, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=int(self.rope_theta), - is_neox_style=True, - ) - - self.attn = RadixAttention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_idx, - prefix=f"{prefix}.attn", - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, forward_batch) - output, _ = self.o_proj(attn_output) - return output - - -class ArcticDecoderLayer(nn.Module): - def __init__( - self, - config: ArcticConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - layer_idx = extract_layer_index(prefix) - is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 - self.use_residual = config.use_residual and is_moe_layer - self.self_attn = ArcticAttention( - config, quant_config=quant_config, prefix=f"{prefix}.self_attn" - ) - self.block_sparse_moe = ArcticMoE( - config, - quant_config=quant_config, - reduce_results=(not self.use_residual), - prefix=f"{prefix}.block_sparse_moe", - ) - - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - if self.use_residual: - self.residual_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.residual_mlp = ArcticMLP( - config, - is_residual_mlp=True, - reduce_results=False, - prefix=f"{prefix}.residual_mlp", - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - residual_input = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) - hidden_states = residual_input + hidden_states - - residual_attn = hidden_states - if self.use_residual: - hidden_states = self.residual_layernorm(hidden_states) - hidden_states = self.residual_mlp(hidden_states) - residual_mlp = hidden_states - hidden_states = self.post_attention_layernorm(residual_input) - hidden_states = self.block_sparse_moe(hidden_states) - hidden_states = residual_mlp + hidden_states - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - hidden_states = residual_attn + hidden_states - else: - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.block_sparse_moe(hidden_states) - hidden_states = residual_attn + hidden_states - return hidden_states - - -class ArcticModel(nn.Module): - def __init__( - self, - *, - config: ArcticConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size - ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: ArcticDecoderLayer(config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers", - ) - self._attn_implementation = config._attn_implementation - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states"], config.hidden_size - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if input_embeds is not None: - hidden_states = input_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - - for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states = layer(positions, hidden_states, forward_batch) - - hidden_states = self.norm(hidden_states) - return hidden_states - - -class ArcticForCausalLM(nn.Module): - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - - def __init__( - self, - *, - config: ArcticConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.supports_torch_tp = True - self.model = ArcticModel( - config=config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "model"), - ) - self.vocab_size = config.vocab_size - self.lm_head = ParallelLMHead( - self.vocab_size, - config.hidden_size, - quant_config=quant_config, - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.num_experts = config.num_local_experts - self.num_experts_per_tok = config.num_experts_per_tok - self.unpadded_vocab_size = config.vocab_size - self.logits_processor = LogitsProcessor(self.config) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: Optional[torch.Tensor] = None, - ) -> LogitsProcessorOutput: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch - ) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - mlp_params_mapping: List[Tuple[str, str, int]] = [] - expert_params_mapping: List[Tuple[str, str, int]] = [] - num_layers = self.config.num_hidden_layers - - for layer in range(num_layers): - mlp_params_mapping.append( - ( - f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w1.weight", - 0, - ) - ) - mlp_params_mapping.append( - ( - f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w3.weight", - 1, - ) - ) - if (layer + 1) % self.config.moe_layer_frequency != 0: - # MLP layers - mlp_params_mapping.append( - ( - f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w1.weight", - 0, - ) - ) - mlp_params_mapping.append( - ( - f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w3.weight", - 1, - ) - ) - else: - # MoE layers - for expert_id in range(self.config.num_local_experts): - expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w1.weight", expert_id) - ) - expert_params_mapping.append( - ("w2s", f"experts.{expert_id}.w2.weight", expert_id) - ) - expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w3.weight", expert_id) - ) - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - - logger.info( - "It will take ~10 minutes loading from the 16-bit weights. " - "Alternatively, use the prequantized 8-bit weights of arctic " - "and set load-format to `sharded_state` will accelerate loading." - ) - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, shard_id in mlp_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, shard_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, loaded_weight, weight_name, expert_id=shard_id - ) - break - else: - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params