################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch_br
|
|
from torch import nn
|
|
from transformers import GptOssConfig
|
|
|
|
import vllm
|
|
import vllm.model_executor.models.gpt_oss
|
|
from vllm.attention import Attention, AttentionType
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.utils import (extract_layer_index,
|
|
is_pp_missing_parameter)
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils import cdiv
|
|
from vllm_br import envs
|
|
|
|
|
|
class OAIAttention(nn.Module):
    """GPT-OSS self-attention layer used on BR devices.

    The assignment after this class installs it as
    ``vllm.model_executor.models.gpt_oss.OAIAttention``. The layer applies
    YaRN rotary embeddings to q/k and keeps one float32 attention-sink value
    per local head. A sliding window is used on even-indexed layers only.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        """Build the rotary embedding, QKV/output projections and attention.

        Args:
            config: Hugging Face GPT-OSS config. Its ``rope_scaling`` must
                contain the YaRN keys read below.
            quant_config: Optional quantization config, forwarded to the
                projections and to ``Attention``.
            cache_config: Optional cache config, forwarded to ``Attention``.
            prefix: Module name prefix; the layer index is extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # Forward the YaRN settings from the HF config unchanged.
        yarn_keys = ("factor", "original_max_position_embeddings",
                     "beta_fast", "beta_slow")
        rope_scaling = {"rope_type": "yarn"}
        for key in yarn_keys:
            rope_scaling[key] = config.rope_scaling[key]
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()
        heads_here = self.num_attention_heads // tp_size
        kv_heads_here = self.num_key_value_heads // tp_size

        # One sink value per attention head on this rank, stored as float32.
        sink_storage = torch.empty(heads_here,
                                   dtype=torch.float32,
                                   requires_grad=False)
        self.sinks = torch.nn.Parameter(sink_storage)

        # Widths of the q / k / v sections of this rank's fused QKV output.
        # NOTE(review): with tp_size > num_key_value_heads the integer
        # divisions here yield kv_size / num_local_key_value_heads that do not
        # match a replicated-KV layout - confirm such configs are unsupported.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = heads_here
        self.num_local_key_value_heads = kv_heads_here

        # Even-indexed layers use the configured sliding window; odd-indexed
        # layers pass None as their per-layer sliding window.
        if self.layer_idx % 2 == 0:
            window = config.sliding_window
        else:
            window = None

        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project out.

        Args:
            hidden_states: Input activations for the QKV projection.
            positions: Token positions consumed by the rotary embedding.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        fused, _ = self.qkv(hidden_states)
        sections = [self.q_size, self.kv_size, self.kv_size]
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # Device-specific split, used when the BR device reports more
            # than 16 SPCs.
            q, k, v = torch_br.split_w_sbp_infer(fused, sections)
        else:
            q, k, v = fused.split(sections, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v.contiguous())
        projected, _ = self.o_proj(attended)
        return projected


vllm.model_executor.models.gpt_oss.OAIAttention = OAIAttention
|
|
|
|
|
|
class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block for GPT-OSS on BR devices.

    The assignment after this class installs it as
    ``vllm.model_executor.models.gpt_oss.MLPBlock``. It owns a bfloat16
    router ``Linear`` and a ``FusedMoE`` expert layer.

    NOTE(review): ``forward`` passes ``self.router.weight`` as
    ``router_logits``. The router's bias is not passed, and the router
    ``Linear`` is never called. Presumably the BR ``FusedMoE`` computes the
    routing logits itself - confirm that it accounts for the router bias.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Create the router and the fused expert layer.

        Args:
            vllm_config: Global vLLM config; the HF config, quantization
                config and parallel config are read from it.
            layer_idx: Index of the decoder layer that owns this block.
            prefix: Module name prefix used for the experts' parameters.
        """
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = hf_config.num_local_experts
        self.experts_per_token = hf_config.num_experts_per_tok
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1
        self.router = torch.nn.Linear(hf_config.hidden_size,
                                      hf_config.num_local_experts,
                                      dtype=torch.bfloat16)
        assert hf_config.intermediate_size % self.world_size == 0
        # NOTE(review): is_sequence_parallel is forwarded here, but forward()
        # does no sequence-parallel chunking/gathering itself - verify that
        # FusedMoE covers it when use_sequence_parallel_moe is enabled.
        self.experts = FusedMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the experts and make sure the output is TP-reduced once.

        Args:
            x: Hidden states. A leading dim of size 1 is dropped via
                ``squeeze(0)`` (presumably the dim added by the patched
                model forward - confirm against the decoder layer).

        Returns:
            The expert output, all-reduced across TP ranks unless the fused
            MoE reported that it already did so.
        """
        out = self.experts(hidden_states=x.squeeze(0),
                           router_logits=self.router.weight)

        # The fused MoE tags its output with `all_reduced` when it has
        # already reduced it; consume the tag and skip a second reduction.
        if hasattr(out, 'all_reduced'):
            delattr(out, 'all_reduced')
            return out
        if self.tp_size > 1:
            out = tensor_model_parallel_all_reduce(out)
        return out


vllm.model_executor.models.gpt_oss.MLPBlock = MLPBlock
|
|
|
|
|
|
def GptOssModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass of ``GptOssModel`` patched for BR devices.

    The hidden states, and the residual when one is received, get a leading
    dimension of size 1 via ``unsqueeze(0)`` before the decoder layers run.
    That dimension is removed with ``squeeze(0)`` on every exit path, so all
    returned tensors match the layout the caller passed in.

    Args:
        input_ids: Token ids; embedded only on the first PP rank when
            ``inputs_embeds`` is not given.
        positions: Token positions, forwarded to each decoder layer.
        intermediate_tensors: Hidden states / residual received from the
            previous pipeline stage (required on non-first PP ranks).
        inputs_embeds: Optional precomputed input embeddings.

    Returns:
        On non-last PP ranks, an ``IntermediateTensors`` holding
        ``"hidden_states"`` and ``"residual"``. On the last rank, the
        normalized hidden states. When any layer listed in
        ``self.aux_hidden_state_layers`` ran on this rank, the last rank
        returns a ``(hidden_states, aux_hidden_states)`` tuple instead.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            x = inputs_embeds
        else:
            x = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        x = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        # The sending rank may forward residual=None (see the
        # IntermediateTensors construction below), so guard the unsqueeze.
        if residual is not None:
            residual = residual.unsqueeze(0)

    x = x.unsqueeze(0)

    aux_hidden_states = []
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        if i in self.aux_hidden_state_layers:
            aux_hidden_states.append(x if residual is None else x + residual)
        x, residual = layer(x, positions, residual)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states":
            x.squeeze(0),
            "residual":
            residual.squeeze(0) if residual is not None else None,
        })

    x, _ = self.norm(x, residual)

    if aux_hidden_states:
        # Undo the leading unsqueeze on this path too; previously only the
        # plain return squeezed, so the aux path leaked the extra dim.
        return x.squeeze(0), [h.squeeze(0) for h in aux_hidden_states]
    return x.squeeze(0)


vllm.model_executor.models.gpt_oss.GptOssModel.forward = GptOssModel_forward
|
|
|
|
|
|
def GptOssModel_load_weights_other(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Copy checkpoint tensors into this model's parameters.

    MoE expert tensors (``w13``/``w2`` weights and biases) are sliced
    differently depending on the parallel mode:

    * With expert parallelism, they are sliced by expert index.
    * Otherwise, they are sliced along this tensor-parallel rank's share of
      the intermediate dimension.

    The sliced tensors are then copied straight into the parameters. The
    other entries are handled as follows:

    * Attention sinks are sliced by head.
    * Names containing a ``weight_name`` from ``stacked_params_mapping`` are
      renamed and handed to the parameter's ``weight_loader`` together with
      the shard id.
    * Any remaining name present in the model goes through its
      ``weight_loader`` (default: ``default_weight_loader``).

    Args:
        ep_rank_end: End (exclusive) of this rank's expert range (EP only).
        ep_rank_start: Start of this rank's expert range (EP only).
        heads_per_rank: Number of attention-sink heads owned by this rank.
        head_start: Index of this rank's first attention-sink head.
        weights: Iterable of ``(name, tensor)`` checkpoint entries. The
            tensors are not modified.
        stacked_params_mapping: ``(param_name, weight_name, shard_id)``
            triples mapping checkpoint names onto fused parameters.

    Returns:
        The set of parameter names that were loaded.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    use_ep = self.parallel_config.enable_expert_parallel

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # [tp_rank_start, tp_rank_end) is this rank's slice of the intermediate
    # dimension; the end is clamped so the last rank never overruns it.
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue

        if ".w13_weight" in name:
            # Gate and up projection weights.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # w13 packs gate and up projections, so its last axis is
                # sliced at 2x the intermediate bounds (assumed layout:
                # (experts, hidden, 2 * intermediate) - confirm).
                narrow_weight = weight[:, :, 2 * tp_rank_start:2 * tp_rank_end]

            # Swap the last two axes to match the parameter layout.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            # Down projection weights.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # Gate and up projection biases (same 2x slicing as w13_weight).
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end]

            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # Down projection bias.
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Only rank 0 keeps the real bias, to avoid duplicating it
                # (presumably because TP partial outputs are summed later).
                # Use a fresh zero tensor instead of weight.zero_(), which
                # would mutate the caller's checkpoint tensor in place.
                weight = torch.zeros_like(weight)
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Attention sinks are split across ranks by head.
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # default_weight_loader takes no shard id; custom loaders do.
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)
    return loaded_params


vllm.model_executor.models.gpt_oss.GptOssModel._load_weights_other = GptOssModel_load_weights_other
|