################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# SPDX-License-Identifier: Apache-2.0
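
# This module adapts vLLM's stock InternLM2 implementation for this backend by
# monkey-patching InternLM2Attention, InternLM2Model and InternLM2MLP; the
# replacement methods are installed by the assignments at the bottom of the
# file.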

from functools import partial
from typing import Optional, Union

import torch

from vllm.distributed import (get_pp_group, split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                  InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors


def internlm2_attention_split_qkv(self, qkv: torch.Tensor):
    # qkv is the packed wqkv output with an explicit batch dim of 1:
    # (1, seq_len, q_size + 2 * kv_size) per tensor-parallel rank.
    seq_len = qkv.shape[1]
    if self.tp_size > 1:
        # Gather the packed qkv from every TP rank, then reorder the per-rank
        # (q, k, v) chunks into contiguous q, k and v blocks.
        qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        qkv = tensor_model_parallel_all_gather(qkv)
        qkv = torch.split(qkv, qkv_map, dim=-1)
        qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
        qkv = torch.cat(qkv, dim=-1)

    # Expose the head structure: one group of query heads plus one key and one
    # value head per KV head.
    qkv = qkv.view(seq_len, self.total_num_kv_heads, self.key_value_groups + 2,
                   self.head_dim)
    q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
    q = q.reshape(seq_len, self.q_size * self.tp_size).unsqueeze(0)
    k = k.reshape(seq_len, self.kv_size * self.tp_size).unsqueeze(0)
    v = v.reshape(seq_len, self.kv_size * self.tp_size).unsqueeze(0)

    if self.tp_size > 1:
        # Each rank keeps only its own slice of the re-partitioned q, k and v.
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]
        v = splitter(v)[self.tp_rank]
    return q, k, v
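

# Shape walk-through for internlm2_attention_split_qkv under a hypothetical
# GQA config (32 query heads, 8 KV heads, head_dim=128, tp_size=2, so per rank
# q_size=2048, kv_size=512 and key_value_groups=4):
#   wqkv output qkv:              (1, seq_len, 3072)
#   after all_gather + regroup:   (1, seq_len, 6144)
#   after view:                   (seq_len, 8, 6, 128)
#   q / k / v after reshape:      (1, seq_len, 4096) / (1, seq_len, 1024) x 2
#   per-rank slices returned:     (1, seq_len, 2048) / (1, seq_len, 512) x 2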


def internlm2_attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    # Packed qkv projection, per-head split, then rotary embedding.
    qkv, _ = self.wqkv(hidden_states)
    q, k, v = self.split_qkv(qkv)
    q, k = self.rotary_emb(positions, q, k)

    attn_output = self.attn(q, k, v)
    output, _ = self.wo(attn_output)
    return output


def internlm2_model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        # First pipeline rank: start from the token embeddings (or the
        # caller-provided embeddings).
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        # Later pipeline ranks: resume from the tensors shipped by the
        # previous rank.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # Keep an explicit batch dimension of 1 while running the decoder layers.
    hidden_states = hidden_states.unsqueeze(0)

    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Hand the (squeezed) activations to the next pipeline rank.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
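

# Pipeline-parallel hand-off (illustrative): every rank except the last one
# returns IntermediateTensors with "hidden_states" and "residual" of shape
# (seq_len, hidden_size); the receiving rank restores the leading batch dim of
# 1 on the hidden states before running its share of the decoder layers.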


def internlm2_mlp_init(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    # Initialize the nn.Module base; this function replaces
    # InternLM2MLP.__init__.
    super(InternLM2MLP, self).__init__()
    # Merged gate/up projection.
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    # Backend-specific flag on the merged projection.
    self.gate_up_proj.no_need_cross = True
    # Down projection (w2).
    self.w2 = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.w2",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()


# Install the patched methods on vLLM's upstream InternLM2 classes.
InternLM2Attention.split_qkv = internlm2_attention_split_qkv
InternLM2Attention.forward = internlm2_attention_forward
InternLM2Model.forward = internlm2_model_forward
InternLM2MLP.__init__ = internlm2_mlp_init
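
# Usage sketch (illustrative, not part of this file): the patches take effect
# at import time, so this module has to be imported before the InternLM2 model
# is built, e.g.
#
#     import <this_module>          # hypothetical import path; installs patches
#     from vllm import LLM
#     llm = LLM(model="internlm/internlm2-chat-7b")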