# Co-authored-by: austindeng <austindeng@tencent.com>
# Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
# Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
# Co-authored-by: ch-wan <cwan39@gatech.edu>
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only DeepSeek NextN Speculative Decoding."""
|
|
import logging
|
|
from typing import Iterable, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|
from sglang.srt.layers.layernorm import RMSNorm
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
|
from sglang.srt.utils import BumpAllocator, add_prefix
|
|
|
|
# Module-level logger (not referenced by any code visible in this chunk).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DeepseekModelNextN(nn.Module):
    """Model body for DeepSeek NextN speculative decoding.

    Normalizes the current token embeddings and the hidden states carried in
    ``forward_batch.spec_info``, concatenates them, projects the result back to
    ``hidden_size`` and runs it through a single ``DeepseekV2DecoderLayer``
    followed by a final RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the NextN sub-modules.

        Args:
            config: Model config; ``vocab_size``, ``hidden_size`` and
                ``rms_norm_eps`` are read here, and the whole config is
                forwarded to the decoder layer.
            quant_config: Optional quantization config forwarded to the
                decoder layer.
            prefix: Parameter-name prefix used when naming sub-modules.
        """
        super().__init__()
        self.vocab_size = config.vocab_size

        # Tensor parallelism for the embedding is turned off when DP attention
        # is enabled.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
            prefix=add_prefix("embed_tokens", prefix),
        )

        # Separate norms: `enorm` for the token embeddings, `hnorm` for the
        # hidden states taken from `forward_batch.spec_info` (see forward()).
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Maps the concatenation [enorm(emb); hnorm(hidden)] of width
        # 2 * hidden_size back down to hidden_size.
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

        # Exactly one decoder layer, built with layer id 0 and is_nextn=True.
        self.decoder = DeepseekV2DecoderLayer(
            config,
            0,
            quant_config=quant_config,
            is_nextn=True,
            prefix=add_prefix("decoder", prefix),
        )

        # Plain container so the final norm is registered as
        # "shared_head.norm"; presumably mirrors checkpoint naming — verify.
        self.shared_head = nn.Module()
        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NextN layer on one batch.

        Args:
            input_ids: Token ids; embedded with ``embed_tokens`` when
                ``input_embeds`` is not supplied.
            positions: Positions forwarded to the decoder layer.
            forward_batch: Batch metadata. ``spec_info.hidden_states`` is fused
                with the embeddings, and ``forward_mode`` decides whether the
                final norm is applied.
            input_embeds: Optional precomputed embeddings used instead of
                ``embed_tokens(input_ids)``.

        Returns:
            The decoder-layer output, passed through ``shared_head.norm``
            (together with the residual) unless the batch is idle.
        """
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Allocator handed to the decoder layer; presumably supplies zeroed
        # float32 scratch space — its use is internal to the decoder, verify.
        zero_allocator = BumpAllocator(
            buffer_size=2,
            dtype=torch.float32,
            device=device,
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Only fuse when there is at least one row; an empty batch skips the
        # norms/projection and goes straight to the decoder layer.
        if hidden_states.shape[0] > 0:
            hidden_states = self.eh_proj(
                torch.cat(
                    (
                        self.enorm(hidden_states),
                        self.hnorm(forward_batch.spec_info.hidden_states),
                    ),
                    dim=-1,
                )
            )

        residual = None
        hidden_states, residual = self.decoder(
            positions, hidden_states, forward_batch, residual, zero_allocator
        )

        # Called with a residual, the norm returns a pair; the second element
        # is discarded.
        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    """DeepSeek-V3 NextN model with an LM head and logits processor.

    Reuses helpers inherited from ``DeepseekV3ForCausalLM``
    (``determine_num_fused_shared_experts`` and ``load_weights``), but its
    ``model`` is a single-layer ``DeepseekModelNextN``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the NextN body, LM head and logits processor.

        Args:
            config: Model config shared by all sub-modules.
            quant_config: Optional quantization config for the body and head.
            prefix: Parameter-name prefix used when naming sub-modules.
        """
        # Deliberately skip DeepseekV3ForCausalLM.__init__ and only set up
        # nn.Module state; presumably the parent constructor builds the
        # full-depth model, which NextN does not need — verify.
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Inherited helper; invoked after config/quant_config are assigned,
        # presumably because it reads them — verify.
        self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")

        self.model = DeepseekModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        # The LM head's parameters are named under "model.shared_head.head".
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("model.shared_head.head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NextN body and compute logits.

        Args:
            input_ids: Token ids for the batch.
            positions: Token positions forwarded to the body.
            forward_batch: Batch metadata forwarded to the body and the
                logits processor.
            input_embeds: Optional precomputed embeddings. They are forwarded
                to ``DeepseekModelNextN.forward``, which already accepts them,
                so callers that pass ``input_embeds`` no longer raise
                ``TypeError``.

        Returns:
            Whatever ``self.logits_processor`` returns for these hidden states.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds=input_embeds
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights through the parent loader with ``is_nextn=True``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        super().load_weights(weights, is_nextn=True)
|
|
|
|
|
|
# Model class(es) exported by this module; presumably discovered by SGLang's
# model registry through the `EntryClass` name — verify against the loader.
EntryClass = [DeepseekV3ForCausalLMNextN]
|