[MOE]move weight transpose to wakeup for RL secnarios (#4626)
### What this PR does / why we need it?
In reinforcement learning scenarios, the current inference applies a
transpose operation to the weights. For a cleaner architecture, the
weight transpose module was moved to wakeup.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -56,29 +56,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
super().__init__(moe=moe)
|
||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||
self.transpose = True
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
if self.transpose:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
|
||||
self.transpose = False
|
||||
else:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
|
||||
):
|
||||
@@ -389,61 +378,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def transpose_weight(self, loaded_weight, expert_data, shard_dim):
|
||||
# Ensure training and inference weight shapes match during RL weight updates
|
||||
if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
|
||||
loaded_weight.shape[1] != expert_data.shape[1] and \
|
||||
loaded_weight.shape[0] != expert_data.shape[0]
|
||||
):
|
||||
shard_dim = int(not shard_dim)
|
||||
loaded_weight = loaded_weight.transpose(0, 1).contiguous()
|
||||
return loaded_weight, shard_dim
|
||||
|
||||
def _load_w13(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user