Improve weight loading and code style (#3174)
This commit is contained in:
@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
|
||||
prefix: str = "",
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
||||
)
|
||||
|
||||
self.gather_output = gather_output
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
if tp_rank is None:
|
||||
@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
|
||||
if len(loaded_weight.shape) == 0:
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
|
||||
param.load_column_parallel_weight(
|
||||
loaded_weight,
|
||||
tp_rank=self.tp_rank,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
prefix=prefix,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.prefix = prefix
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
prefix: str = "",
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
load_presharded_attn: bool = False,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
]
|
||||
self.use_presharded_weights = load_presharded_attn
|
||||
|
||||
super().__init__(
|
||||
input_size=input_size,
|
||||
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
prefix=prefix,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size=shard_size, shard_offset=shard_offset
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
param.output_dim, shard_offset, shard_size
|
||||
)
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
param.output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def weight_loader_v2(
|
||||
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size,
|
||||
tp_rank=self.tp_rank,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
|
||||
def weight_loader(
|
||||
@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param, orig_qkv_offsets, shard_id
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size
|
||||
)
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit:
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for AQLM codebooks.
|
||||
|
||||
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
activation: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.activation = activation
|
||||
|
||||
if quant_config is None:
|
||||
@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
)
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||
@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
if self.activation == "silu":
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
) -> List[Tuple[str, str, int, str]]:
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
(
|
||||
@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
|
||||
)
|
||||
return
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
if shard_id == "w2":
|
||||
param.data[expert_id] = loaded_weight
|
||||
elif shard_id == "w1":
|
||||
|
||||
@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
|
||||
def load_qkv_weight(
|
||||
self,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
use_presharded_weights: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
shard_offset = kwargs.get("shard_offset")
|
||||
shard_size = kwargs.get("shard_size")
|
||||
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
param_data = self.data
|
||||
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
if not use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
assert (
|
||||
param_data.shape == loaded_weight.shape
|
||||
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
|
||||
Reference in New Issue
Block a user