[Bugfix] Fix matmul allreduce precision issue by using original weight (#4939)

### What this PR does / why we need it?

This PR fixes a precision issue caused by improper tensor maintenance in
`vllm_ascend/ops/linear_op.py` in the Verl reinforcement learning (RL)
scenario. Issue: https://github.com/vllm-project/vllm-ascend/issues/5747
Key changes:
1. Remove the custom class member `self.weight_t` from
`vllm_ascend/ops/linear_op.py`.
2. Adjust the input of the `npu_mm_all_reduce_base` operator so the weight
is fetched directly from the model's `nn.Parameter` at call time instead of
from a pre-created tensor (a simplified sketch follows below).
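
In code terms, the change amounts to the following simplified sketch of the
affected path in `MatmulAllreduceRowParallelOp` (surrounding logic omitted;
the call arguments are the ones shown in the diff further down):

```python
# Before: a transposed copy of the weight was cached in update_attrs() and
# reused across calls; it can go stale after Verl's weight synchronization.
self.weight_t = self.layer.weight.t()
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                          self.weight_t,
                                          self.hcomm_info,
                                          bias=bias_)

# After: the transpose is taken from the live nn.Parameter on every call,
# so it always reflects the parameter the model currently holds.
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                          self.layer.weight.t(),
                                          self.hcomm_info,
                                          bias=bias_)
```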

> In vLLM models, avoid creating extra copies of parameters (such as
> `self.weight_t`) for computation; if such a copy already exists, it must
> be kept in sync with the model's original parameters. In the Verl
> reinforcement learning (RL) scenario, parameter synchronization between
> training and inference may change the memory address of an
> `nn.Parameter`; an unsynchronized extra tensor then keeps referencing the
> old memory and never sees the updated weights, which ultimately leads to
> precision issues.
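
The failure mode can be reproduced with a plain `nn.Linear`; the following
is a minimal, self-contained sketch (not code from this PR) illustrating why
a cached transposed tensor goes stale when the parameter object is replaced:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4, bias=False)

# Anti-pattern (analogous to self.weight_t): cache a derived tensor once.
cached_weight_t = layer.weight.t()

# Simulate the training->inference weight sync: the nn.Parameter object is
# replaced, so the module now points at different storage.
layer.weight = nn.Parameter(torch.randn(4, 4))

# The cached tensor still views the old storage ...
print(cached_weight_t.data_ptr() == layer.weight.data_ptr())  # False
# ... and no longer matches the weights the model actually uses.
print(torch.equal(cached_weight_t, layer.weight.t()))         # False

# Taking the transpose from the live parameter at call time always matches.
fresh_weight_t = layer.weight.t()
print(torch.equal(fresh_weight_t, layer.weight.t()))          # True
```

Removing the cached copy and reading `self.layer.weight` directly avoids
having to keep any extra tensor in sync with the model's parameters.
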
### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: icerain-alt <450125138@qq.com>
Co-authored-by: Shangwei-Li <lishangwei@mail.ustc.edu.cn>
Author: ice_rain
Date: 2026-01-09 16:05:32 +08:00
Committed by: GitHub
Parent: 64d29875f9
Commit: 09682e0751


`vllm_ascend/ops/linear_op.py`:

```diff
@@ -423,7 +423,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         if self.reduce_results and self.tp_size > 1:
             output = torch_npu.npu_mm_all_reduce_base(input_parallel,
-                                                      self.weight_t,
+                                                      self.layer.weight.t(),
                                                       self.hcomm_info,
                                                       bias=bias_)
         else:
@@ -450,10 +450,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
             cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
         return cls._HCOMM_INFO

-    def update_attrs(self):
-        super().update_attrs()
-        self.weight_t = self.layer.weight.t()
-
 class SequenceColumnParallelOp(CustomColumnParallelOp):
```