[Feature] Support XiaoMi MIMO Flash V2 (#62)
* [Feature] Support MIMO Flash V2
This commit is contained in:
@@ -15,8 +15,8 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
# import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
@@ -1,3 +1,20 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
import torch
|
||||
@@ -177,51 +194,10 @@ class KunlunOps:
|
||||
"""
|
||||
query_x = query.contiguous()
|
||||
key_x = key.contiguous()
|
||||
query_x_dim = query_x.dim()
|
||||
if not is_neox_style:
|
||||
if cos_sin_cache.dtype == torch.float16:
|
||||
cos_sin_cache = cos_sin_cache.to(torch.float32)
|
||||
positions = positions.to(torch.int)
|
||||
if positions.dim() == 1:
|
||||
positions = positions.unsqueeze(0)
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
|
||||
xtorch_ops.rotary_embedding_gptj(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache)
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
if query_x_dim != query_x.dim():
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
return query, key
|
||||
|
||||
# TODO: need opt
|
||||
if cos_sin_cache.dim() == 4:
|
||||
max_seq_len = cos_sin_cache.shape[2]
|
||||
head_dim = cos_sin_cache.shape[3]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
|
||||
|
||||
# 重塑 query 和 key 的形状
|
||||
num_tokens = query_x.shape[0]
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
||||
|
||||
# # 确保形状正确
|
||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
@@ -234,8 +210,6 @@ class KunlunOps:
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
|
||||
# query.data = query_x
|
||||
# key.data = key_x
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@@ -433,6 +407,121 @@ class KunlunOps:
|
||||
|
||||
return out
|
||||
|
||||
def _dbg(x):
|
||||
if torch.is_tensor(x):
|
||||
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
|
||||
return (type(x), x)
|
||||
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    linear_weights: torch.Tensor,
    moe_top_k: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fused mixture-of-experts forward pass built from Kunlun custom ops.

    Pipeline: sigmoid/group top-k routing -> token sort by expert ->
    expert FC1 -> SwiGLU -> expert FC2 -> weighted scatter back to tokens.

    Args:
        hidden_states: [M, N] input tokens.
        w1: first (gate+up) expert weight; its dim 1 sizes the FC1 output.
        w2: second (down) expert weight; w2.shape[1] is the FC2 output width.
        router_logits: per-token router scores fed to the routing kernel.
        linear_weights: router weight; only shape[0] (expert count) is read here.
        moe_top_k: number of experts selected per token.
        e_score_correction_bias: optional routing bias passed to the router op.

    Returns:
        [M, N] tensor of combined expert outputs.

    NOTE(review): renormalize, inplace, use_grouped_topk, topk_group,
    w1_bias, w2_bias and scoring_func are accepted but never used in this
    body — presumably kept for interface parity with other backends; confirm.
    """
    global_num_experts = linear_weights.shape[0]
    M, N = hidden_states.shape
    hidden_dim = w2.shape[1]
    # Per-token routing outputs: normalized scores and selected expert ids.
    normed_score = torch.empty(M,
                               moe_top_k,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M,
                           moe_top_k,
                           dtype=torch.int32,
                           device=hidden_states.device)
    # Per-block expert occupancy counters used by the sorting kernels.
    # NOTE(review): num_blocks=12 looks hardware-specific (cluster count?) — confirm.
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
    )

    # Sigmoid + grouped top-k routing with score normalization.
    # NOTE(review): topk_group is hardcoded to 1 here even though the
    # function accepts a topk_group parameter — verify this is intended.
    torch.ops._C.moe_sigmoid_group_topk_norm(
        x=router_logits,
        topk_index=topk_ids,
        norm_score=normed_score,
        block_static=block_statistic,
        bias=e_score_correction_bias,
        scale=1.0,
        n_group=num_expert_group,
        topk_group=1,
    )

    # Buffers for the expert-sorted token layout.
    moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
    expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
    sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
    sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)

    # Recompute block statistics from the selected expert ids.
    # NOTE(review): block_statistic was already written by the routing op
    # above; this pass appears to overwrite it — confirm both are needed.
    torch.ops._C.gen_block_statistic(topk_ids,block_statistic)

    # Gather tokens into expert-contiguous order (moe_expand) and produce
    # the per-expert LOD / index mapping used by the grouped GEMMs below.
    torch.ops._C.moe_pre_sorted(
        x=hidden_states,
        topk_index=topk_ids,
        block_statistic=block_statistic,
        moe_expand=moe_expand,
        moe_index=sorted_tokens_idx,
        expert_m=expert_m,
        sorted_tokens_num_lod=sorted_tokens_num_lod)

    # FC1 output: one row per (token, selected expert).
    y = torch.empty(M,moe_top_k,
                    w1.shape[1],
                    dtype=hidden_states.dtype,
                    device=hidden_states.device)

    # NOTE(review): this view assumes hidden_dim (w2.shape[1]) == N; it
    # fails if the model hidden size differs from w2's dim 1 — confirm.
    moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)

    # Grouped FC1 over expert-sorted tokens.
    torch.ops._C.moe_fc(
        x=moe_expand,
        weight=w1,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_top_k,
        y=y)

    # SwiGLU halves the last dimension: y = [..., 2d] -> out1 = [..., d].
    d = y.shape[-1] // 2
    output_shape = (y.shape[:-1] + (d, ))
    out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
    torch.ops._C.swiglu(y, out1)

    # FC2 output, still one row per (token, expert).
    out = torch.empty(M,moe_top_k,
                      w2.shape[1],
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)

    out1 = out1.reshape(-1, out1.shape[-1])

    # Grouped FC2 (down projection).
    torch.ops._C.moe_fc(
        x=out1,
        weight=w2,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_top_k,
        y=out)

    # No quantization here, so the dequant scale is identity (all ones).
    dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
    output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
    sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)

    # Scatter expert outputs back to token order, weighted by routing scores.
    torch.ops._C.moe_post(
        x=out,
        moe_index=sorted_tokens_idx,
        normed_scale=normed_score,
        dequant_scale=dequant_scale,
        y=output
    )

    return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe_ep(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -487,42 +576,6 @@ class KunlunOps:
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    linear_weights: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Single-kernel MoE forward via the monolithic moe_ffn_block custom op.

    NOTE(review): a method with the same name `fused_moe` is defined earlier
    in this class; in Python the later definition shadows the earlier one.
    The surrounding patch suggests this variant is being removed — confirm
    which implementation is intended to win.

    Args:
        hidden_states: input tokens; output has the same shape and dtype.
        w1 / w2: intermediate and output expert weights.
        gating_output: accepted but not used in this body — TODO confirm.
        linear_weights: router (gate) weight; shape[0] is the expert count.
        topk: experts selected per token.

    Returns:
        Tensor shaped like hidden_states with the MoE FFN applied.

    NOTE(review): inplace, w1_bias and w2_bias are accepted but never
    forwarded to the kernel.
    """
    output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
                         device=hidden_states.device)
    expert_num = linear_weights.shape[0]

    # Routing + expert FFN fused into one custom op; writes into `output`.
    torch.ops._C.moe_ffn_block(
        x=hidden_states,
        gate_w=linear_weights,
        inter_w=w1,
        output_w=w2,
        expert_num=expert_num,
        moe_top_k=topk,
        topk_group=topk_group,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        expert_group_num=num_expert_group,
        out=output,
    )
    return output
|
||||
|
||||
@staticmethod
|
||||
def fused_multi_head_latent_page_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -68,7 +68,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights)
|
||||
linear_weights=linear_weights,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
@@ -81,7 +82,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
@@ -99,96 +102,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
# fused_moe do not support expert number > 400
|
||||
elif layer.local_num_experts > 400:
|
||||
hidden_states = x
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
router_logits = router_logits.float()
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=layer.w13_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.swiglu(y, out1)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=layer.w2_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
@@ -200,7 +113,9 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
|
||||
@@ -57,6 +57,8 @@ def vllm_kunlun_forward_cuda(
|
||||
)
|
||||
return out
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
|
||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
@staticmethod
|
||||
|
||||
@@ -3,14 +3,30 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ReplicatedLinear,
|
||||
UnquantizedLinearMethod,
|
||||
ColumnParallelLinear
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -59,4 +75,361 @@ def create_weights(
|
||||
|
||||
# Rewrite create_weights and remove weight_loader_v2 to support cuda graph.
UnquantizedLinearMethod.create_weights = create_weights
# Guarded removal: the original called list.remove twice, and the second
# call raises ValueError once the entry is gone (e.g. on module re-import).
if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
    WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
Base on v0.11.0 QKVParallelLinear, And add v_head size for swa (MIMO V2)
|
||||
"""
|
||||
def __init__(
    self,
    hidden_size: int,
    head_size: int,
    total_num_heads: int,
    total_num_kv_heads: int | None = None,
    bias: bool = True,
    skip_bias_add: bool = False,
    params_dtype: torch.dtype | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
    v_head_size: int | None = None,
):
    """Build a fused QKV projection that may use a distinct value head size.

    Compared to upstream QKVParallelLinear, `v_head_size` lets the v
    projection use a different per-head width than q/k (needed by models
    such as MIMO V2 per the class docstring).

    Args:
        hidden_size: input feature width.
        head_size: per-head width of the q and k projections.
        total_num_heads: global (pre tensor-parallel) q head count.
        total_num_kv_heads: global kv head count; defaults to
            total_num_heads (i.e. no grouped-query sharing).
        v_head_size: per-head width of the v projection; defaults to
            head_size when None.
        disable_tp: when True, sizes are computed as if world size were 1.
    """
    self.hidden_size = hidden_size
    self.head_size = head_size
    # v may have its own head width; fall back to the q/k head size.
    self.v_head_size = v_head_size if v_head_size is not None else head_size
    self.total_num_heads = total_num_heads
    if total_num_kv_heads is None:
        total_num_kv_heads = total_num_heads
    self.total_num_kv_heads = total_num_kv_heads
    # Divide the weight matrix along the last dimension.
    tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
    self.num_heads = divide(self.total_num_heads, tp_size)
    if tp_size >= self.total_num_kv_heads:
        # More ranks than kv heads: each rank holds one kv head and the
        # same kv head is replicated across several ranks.
        self.num_kv_heads = 1
        self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
    else:
        # Enough kv heads to shard them evenly across ranks.
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        self.num_kv_head_replicas = 1
    input_size = self.hidden_size
    # Per-rank q + k + v widths, scaled back up by tp_size because the
    # parent ColumnParallelLinear divides output_size by tp_size again.
    output_size = (
        self.num_heads * self.head_size
        + self.num_kv_heads * self.head_size
        + self.num_kv_heads * self.v_head_size
    ) * tp_size
    self.output_sizes = [
        self.num_heads * self.head_size * tp_size,  # q_proj
        self.num_kv_heads * self.head_size * tp_size,  # k_proj
        self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
    ]

    super().__init__(
        input_size=input_size,
        output_size=output_size,
        bias=bias,
        gather_output=False,
        skip_bias_add=skip_bias_add,
        params_dtype=params_dtype,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=return_bias,
        disable_tp=disable_tp,
    )
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
    """Return the per-rank output-dim offset where a q/k/v shard starts.

    "total" maps to the full per-rank qkv width; unknown ids yield None.
    """
    q_width = self.num_heads * self.head_size
    k_width = self.num_kv_heads * self.head_size
    v_width = self.num_kv_heads * self.v_head_size
    offsets = {
        "q": 0,
        "k": q_width,
        "v": q_width + k_width,
        "total": q_width + k_width + v_width,
    }
    return offsets.get(loaded_shard_id)
|
||||
|
||||
def _get_shard_size_mapping(self, loaded_shard_id: str):
    """Return the per-rank output-dim width of a q/k/v shard.

    The v shard uses v_head_size, which may differ from head_size.
    Unknown ids yield None.
    """
    if loaded_shard_id == "q":
        return self.num_heads * self.head_size
    if loaded_shard_id == "k":
        return self.num_kv_heads * self.head_size
    if loaded_shard_id == "v":
        return self.num_kv_heads * self.v_head_size
    return None
|
||||
|
||||
def _load_fused_module_from_checkpoint(
    self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
    """
    Handle special case for models where QKV layers are already
    fused on disk. In this case, we have no shard id. This function
    determines the shard id by splitting these layers and then calls
    the weight loader using the shard id.

    An example of a model with these fused layers:
    https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
    """
    # Offsets/sizes are in GLOBAL (pre tensor-parallel) units because the
    # on-disk fused weight is unsharded; the v shard uses v_head_size.
    shard_offsets = [
        # (shard_id, shard_offset, shard_size)
        ("q", 0, self.total_num_heads * self.head_size),
        (
            "k",
            self.total_num_heads * self.head_size,
            self.total_num_kv_heads * self.head_size,
        ),
        (
            "v",
            (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
            self.total_num_kv_heads * self.v_head_size,
        ),
    ]

    for shard_id, shard_offset, shard_size in shard_offsets:
        # Special case for Quantization.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        if (
            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim
        ):
            shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                shard_size=shard_size, shard_offset=shard_offset
            )

        # Slice this shard out of the fused tensor (zero-copy view) and
        # re-enter the loader with an explicit shard id.
        loaded_weight_shard = loaded_weight.narrow(
            param.output_dim, shard_offset, shard_size
        )
        self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def weight_loader_v2(
    self,
    param: BasevLLMParameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: str | None = None,
):
    """Load a checkpoint tensor into this layer's parameter (v2 API).

    With a shard id ("q"/"k"/"v") the weight is placed at that shard's
    offset; without one the tensor is either loaded whole (scale/replicated
    params) or split into q/k/v via _load_fused_module_from_checkpoint.
    """
    if loaded_shard_id is None:  # special case for certain models
        if isinstance(param, PerTensorScaleParameter):
            # Single scalar scale shared by all shards; shard_id 0 suffices.
            param.load_qkv_weight(
                loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
            )
            return
        elif type(param) in (RowvLLMParameter, BasevLLMParameter):
            # Exact-type check (not isinstance): subclasses fall through to
            # the fused-module path below.
            param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
            return
        # TODO: @dsikka - move to parameter.py
        self._load_fused_module_from_checkpoint(param, loaded_weight)
        return

    assert loaded_shard_id in ["q", "k", "v"]

    shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
    shard_size = self._get_shard_size_mapping(loaded_shard_id)

    # Note(simon): This is needed for Qwen3's fp8 quantization.
    if isinstance(param, BlockQuantScaleParameter):
        assert self.quant_method is not None
        # Assume the weight block size has been set by quant method
        assert hasattr(self, "weight_block_size")
        weight_block_size = self.weight_block_size
        assert weight_block_size is not None
        # Scales are stored one-per-block along the output dim, so offsets
        # and sizes shrink by block_n (ceiling division).
        block_n, _ = weight_block_size[0], weight_block_size[1]
        shard_offset = (shard_offset + block_n - 1) // block_n
        shard_size = (shard_size + block_n - 1) // block_n

    param.load_qkv_weight(
        loaded_weight=loaded_weight,
        num_heads=self.num_kv_head_replicas,
        shard_id=loaded_shard_id,
        shard_offset=shard_offset,
        shard_size=shard_size,
        tp_rank=self.tp_rank,
    )
|
||||
|
||||
def weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: str | None = None,
):
    """Load a checkpoint tensor into this layer's parameter (legacy API).

    Handles plain weights plus GGUF, packed-quantized (Marlin/GPTQ-style)
    and bitsandbytes-4bit layouts. With loaded_shard_id None the tensor is
    assumed q/k/v-fused on disk and is split recursively; otherwise it is
    copied into the shard's slice for this tensor-parallel rank.

    NOTE(review): adjust_scalar_to_fused_array, adjust_marlin_shard and
    adjust_bitsandbytes_4bit_shard are not among the imports visible in
    this file's header — confirm they are imported elsewhere in the module.
    """
    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        idx_map = {"q": 0, "k": 1, "v": 2}
        if loaded_shard_id is not None:
            # Record the scalar quantize-type code for this shard.
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        else:
            param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
        return

    if is_gguf_weight:
        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // self.tp_size
        start_idx = self.tp_rank * shard_size

        if loaded_shard_id is not None:
            # Stash the sharded tensor; GGUF materializes the real
            # parameter later from data_container.
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
        return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv).
        # (e.g., Phi-3's qkv_proj).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0
                )

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        # Global (pre-TP) offsets/sizes for splitting the fused tensor;
        # the v shard uses v_head_size, which may differ from head_size.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.v_head_size,
            ),
        ]
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        packed_dim = getattr(param, "packed_dim", None)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            if use_bitsandbytes_4bit:
                # bnb-4bit stores its own layout; map (offset, size) through
                # the bnb adjustment using the global qkv extents.
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (
                        self.total_num_heads * self.head_size,
                        self.total_num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.total_num_heads + self.total_num_kv_heads)
                        * self.head_size,
                        self.total_num_kv_heads * self.v_head_size,
                    ),
                    "total": (
                        (self.total_num_heads + self.total_num_kv_heads)
                        * self.head_size
                        + self.total_num_kv_heads * self.v_head_size,
                        0,
                    ),
                }

                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, shard_id
                )

            # Slice this shard out of the fused tensor and re-enter the
            # loader with an explicit shard id.
            loaded_weight_shard = loaded_weight.narrow(
                output_dim, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        # Per-rank offsets/sizes for the destination slice.
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.v_head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.packed_factor
            shard_offset = shard_offset // param.packed_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset
            )

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        if use_bitsandbytes_4bit:
            # Per-rank qkv extents for the bnb layout adjustment.
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (
                    self.num_heads * self.head_size,
                    self.num_kv_heads * self.head_size,
                ),
                "v": (
                    (self.num_heads + self.num_kv_heads) * self.head_size,
                    self.num_kv_heads * self.v_head_size,
                ),
                "total": (
                    (self.num_heads + self.num_kv_heads) * self.head_size
                    + self.num_kv_heads * self.v_head_size,
                    0,
                ),
            }
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, orig_qkv_offsets, loaded_shard_id
            )

        param_data = param_data.narrow(output_dim, shard_offset, shard_size)
        if loaded_shard_id == "q":
            shard_rank = self.tp_rank
        else:
            # kv heads may be replicated across ranks; replicas of the same
            # kv head read the same source slice.
            shard_rank = self.tp_rank // self.num_kv_head_replicas
        start_idx = shard_rank * shard_size

        if not is_sharded_weight:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id
        )
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions."
            )

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@@ -8,14 +8,8 @@ from typing import List, Optional, Tuple
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
if current_platform.is_kunlun():
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
@@ -70,7 +70,7 @@ def vllm_kunlun_forward_cuda(
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
query, key = ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
@@ -143,14 +143,11 @@ def vllm_kunlun_mrope_forward_cuda(
|
||||
|
||||
return query, key
|
||||
|
||||
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
|
||||
@@ -1,143 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Plain (non-quantized) weight handling for embedding layers."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the embedding weight and register it on *layer*."""
        total_rows = sum(output_partition_sizes)
        weight = Parameter(
            torch.empty(total_rows, input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        # Stored layout: output dim is dim 0, input dim is dim 1.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the (unquantized) GEMM for logits computation."""
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings from the layer's weight table."""
        return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    num_chunks = (vocab_size + pad_to - 1) // pad_to
    return num_chunks * pad_to
|
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    """Return the half-open ``[start, end)`` vocab index range owned by
    ``rank``, shifted by ``offset``."""
    start = offset + rank * per_partition_vocab_size
    return start, start + per_partition_vocab_size
|
||||
|
||||
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    """Split ``global_vocab_size`` evenly across ``world_size`` ranks and
    return this ``rank``'s ``[start, end)`` range, shifted by ``offset``.

    ``divide`` asserts exact divisibility, so every shard is the same size.
    """
    shard_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(shard_size,
                                                     rank,
                                                     offset=offset)
|
||||
|
||||
|
||||
@dataclass
class VocabParallelEmbeddingShardIndices:
    """Index bookkeeping for one TP shard of a vocab-parallel embedding.

    Tracks both the padded ranges (positions in the sharded tensor) and the
    unpadded ranges (actual token ids) for the base vocab and the
    LoRA-added vocab.
    """
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        """Real (unpadded) base-vocab entries in this shard."""
        start, end = self.org_vocab_start_index, self.org_vocab_end_index
        return end - start

    @property
    def num_added_elements(self) -> int:
        """Real (unpadded) added-vocab entries in this shard."""
        start, end = self.added_vocab_start_index, self.added_vocab_end_index
        return end - start

    @property
    def num_org_elements_padded(self) -> int:
        """Base-vocab entries in this shard, padding included."""
        start = self.padded_org_vocab_start_index
        end = self.padded_org_vocab_end_index
        return end - start

    @property
    def num_added_elements_padded(self) -> int:
        """Added-vocab entries in this shard, padding included."""
        start = self.padded_added_vocab_start_index
        end = self.padded_added_vocab_end_index
        return end - start

    @property
    def num_org_vocab_padding(self) -> int:
        """Padding slots that follow the base vocab in this shard."""
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        """Padding slots that follow the added vocab in this shard."""
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        """Total shard size (base + added, padding included)."""
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # Every range must be well-formed: start <= end.
        well_formed = (
            (self.padded_org_vocab_start_index,
             self.padded_org_vocab_end_index),
            (self.padded_added_vocab_start_index,
             self.padded_added_vocab_end_index),
            (self.org_vocab_start_index, self.org_vocab_end_index),
            (self.added_vocab_start_index, self.added_vocab_end_index),
        )
        for lo, hi in well_formed:
            assert lo <= hi

        # Unpadded ranges must sit inside their padded counterparts.
        contained = (
            (self.org_vocab_start_index, self.padded_org_vocab_start_index),
            (self.added_vocab_start_index,
             self.padded_added_vocab_start_index),
            (self.org_vocab_end_index, self.padded_org_vocab_end_index),
            (self.added_vocab_end_index, self.padded_added_vocab_end_index),
        )
        for inner, outer in contained:
            assert inner <= outer

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend="aot_eager")
|
||||
def get_masked_input_and_mask(
|
||||
@@ -159,319 +27,25 @@ def get_masked_input_and_mask(
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
|
||||
def forward_native_kunlun(self, input_):
    """Vocab-parallel embedding lookup, installed as ``forward_native``.

    With TP > 1, out-of-shard token ids are remapped/masked before the
    local table lookup and their rows are zeroed afterwards, so the final
    all-reduce sums one valid embedding per token across ranks.
    """
    sharded = self.tp_size > 1
    masked_input = input_
    input_mask = None
    if sharded:
        # Map token ids into this shard's local index space and record
        # which positions fall outside the shard.
        shard = self.shard_indices
        masked_input, input_mask = get_masked_input_and_mask(
            input_, shard.org_vocab_start_index,
            shard.org_vocab_end_index,
            shard.num_org_vocab_padding,
            shard.added_vocab_start_index,
            shard.added_vocab_end_index)
    # Local lookup on this rank's slice of the embedding table.
    embeddings = self.quant_method.embedding(self, masked_input.long())
    if sharded:
        # Zero rows belonging to tokens owned by other ranks (in place).
        embeddings.masked_fill_(input_mask.unsqueeze(-1), 0)
    # Sum partial embeddings across tensor-parallel ranks.
    return tensor_model_parallel_all_reduce(embeddings)
|
||||
|
||||
@CustomOp.register("vllm_kunlun_vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.
    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:
    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  | 1  | ... | 1009  |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                     index: |  0  | 1  | ... | 1009  | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 1010 | ... | 1025 |  -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  | 528  | ... | 543 |
    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |   -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  | 528 | ... |  543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        # Pad the base vocab first, append the added (LoRA) vocab, then pad
        # the total again — matching the layout in the class docstring.
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        quant_method = None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if quant_method is None:
            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
        quant_method_implements_embedding = method_has_implemented_embedding(
            type(quant_method))
        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.quant_method.create_weights(self,
                                         self.embedding_dim,
                                         [self.num_embeddings_per_partition],
                                         self.embedding_dim,
                                         self.num_embeddings_padded,
                                         params_dtype=params_dtype,
                                         weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        # Base and added vocab are sharded independently across ranks.
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding: clamp the ranges to the real (unpadded) vocab sizes.
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)

    def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the class
        docstring. However, after the gather, we want to reindex the final
        logits tensor to map index->token_id one-to-one (the index is always
        equal the token_id it corresponds to). The indices returned by this
        method allow us to do that.
        """
        if self.tp_size < 2:
            # Unsharded: gathered layout is already index == token_id.
            return None

        base_embeddings: list[int] = []
        added_embeddings: list[int] = []
        padding: list[int] = []
        for tp_rank in range(self.tp_size):
            # Recompute each rank's shard layout and record where its base,
            # added, and padding entries land in the gathered tensor.
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        # Final order: all base tokens, then added tokens, then padding —
        # so position i holds token i for all real tokens.
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's shard of ``loaded_weight`` into ``param``,
        zero-filling the padding rows."""
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            # Lazily materialize with the padded partition size on the
            # sharded dimension.
            shape = list(loaded_weight.shape)
            if output_dim is not None:
                shape[output_dim] = self.num_embeddings_per_partition
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # If parameter does not have output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = param.packed_factor if isinstance(
                param, BasevLLMParameter) else param.pack_factor
            # NOTE(review): this assert reads `param.packed_factor`, but the
            # fallback above allows params that only define `pack_factor`;
            # for those this line raises AttributeError. It likely should
            # use the local `packed_factor` — confirm against callers.
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       param.packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data. Select chunk corresponding to current shard.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
        # Zero the padding rows beyond the real shard data.
        param[loaded_weight.shape[0]:].data.fill_(0)

    def forward(self, input_):
        """Look up embeddings for ``input_`` token ids, masking and
        all-reducing across tensor-parallel ranks when sharded."""
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        """Summarize shard/padding configuration for ``repr``/``print``."""
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f', num_embeddings_padded={self.num_embeddings_padded}'
        s += f', tp_size={self.tp_size}'
        return s
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
    """Vocab-parallel LM head holding the output-logits weight matrix.

    The weight (and optional bias) are sharded and padded exactly like a
    ``VocabParallelEmbedding``. The sampler consumes the weights directly,
    so calling ``forward`` is an error.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        self.quant_config = quant_config
        if not bias:
            # Register an explicit None so state-dict handling stays uniform.
            self.register_parameter("bias", None)
            return
        # One bias value per (padded, sharded) vocab entry.
        bias_param = Parameter(
            torch.empty(self.num_embeddings_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(bias_param, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
        self.bias = bias_param

    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
        """Tie the weights with word embeddings."""
        if self.quant_config and self.quant_config.get_name() == "gguf":
            # GGUF quantized embed_tokens: use the embedding module itself
            # as the head instead of aliasing its weight tensor.
            return embed_tokens
        self.weight = embed_tokens.weight
        return self

    def forward(self, input_):
        """Always raises: the head's weight is consumed by the sampler."""
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
# Install the Kunlun lookup as the class's native forward implementation.
VocabParallelEmbedding.forward_native = forward_native_kunlun
|
||||
Reference in New Issue
Block a user