Improve weight loading and code style (#3174)

This commit is contained in:
Lianmin Zheng
2025-01-27 03:00:41 -08:00
committed by GitHub
parent 351a72d40b
commit 53cef81587
11 changed files with 171 additions and 65 deletions

View File

@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
use_presharded_weights: bool = False,
):
super().__init__(
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
)
self.gather_output = gather_output
self.use_presharded_weights = use_presharded_weights
# Divide the weight matrix along the last dimension.
if tp_rank is None:
@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
param.load_column_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix=prefix,
tp_rank=tp_rank,
tp_size=tp_size,
use_presharded_weights=use_presharded_weights,
)
self.prefix = prefix
def weight_loader(
self,
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
load_presharded_attn: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
self.use_presharded_weights = load_presharded_attn
super().__init__(
input_size=input_size,
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix=prefix,
tp_rank=tp_rank,
tp_size=tp_size,
use_presharded_weights=self.use_presharded_weights,
)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
if not self.use_presharded_weights:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def weight_loader(
@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
param, orig_qkv_offsets, shard_id
)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
if not self.use_presharded_weights:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for AQLM codebooks.

View File

@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
):
super().__init__()
@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
if quant_config is None:
@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
dtype=torch.float32,
device=hidden_states.device,
)
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
# GroupGemm-1
down_output = torch.empty(
@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
(
@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
)
return
expert_data = param.data[expert_id]
if shard_id == "w2":
param.data[expert_id] = loaded_weight
elif shard_id == "w1":

View File

@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
def load_qkv_weight(
self,
loaded_weight: torch.Tensor,
tp_rank: int,
use_presharded_weights: bool = False,
**kwargs,
):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
param_data.copy_(loaded_weight)
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim