[BugFix] explicitly setting the tensor shape of otp output (#3027)
Enabling MTP together with oprojTP triggers recompilation of the torchair
graph, which degrades performance. This PR fixes the issue by explicitly
setting the shape of the otp output tensor.
- vLLM version: v0.10.2
- vLLM main: 486c5599e3
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@@ -299,6 +299,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
 
 # otp-specific: Combine partial results across devices
 output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+output = output.view(input_.shape[0], self.layer.output_size)
 
 # Handle bias return based on configuration
 output_bias = self.bias if self.skip_bias_add else None
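
The shape bookkeeping behind the added `view` call can be sketched in plain Python, with no torch dependency. This is an illustration under stated assumptions, not the code from this repository: it assumes that `output_parallel` holds `world_size * input_.shape[0]` rows before `reduce_scatter`, which this diff does not confirm, and the helper names `reduce_scatter_shape` and `otp_output_shape` are hypothetical. Under that assumption the `view` does not change the shape values; presumably its benefit is that the output shape is stated directly in terms of `input_.shape[0]` rather than inferred through the collective.

```python
def reduce_scatter_shape(shape, world_size, dim=0):
    """Per-rank shape after a reduce_scatter along `dim` (hypothetical helper)."""
    if shape[dim] % world_size != 0:
        raise ValueError("scatter dimension must be divisible by world_size")
    out = list(shape)
    out[dim] //= world_size
    return tuple(out)


def otp_output_shape(input_shape, output_size, world_size):
    """Shape the added view() call requests, checked against the scattered shape.

    Assumption (not taken from the diff): the partial output fed to
    reduce_scatter has world_size * input_shape[0] rows.
    """
    partial = (input_shape[0] * world_size, output_size)
    scattered = reduce_scatter_shape(partial, world_size, dim=0)
    # Mirrors: output.view(input_.shape[0], self.layer.output_size)
    explicit = (input_shape[0], output_size)
    if scattered != explicit:
        raise ValueError("view() would change the number of elements")
    return explicit


if __name__ == "__main__":
    # Example sizes are arbitrary and chosen only for illustration.
    print(otp_output_shape((8, 4096), 7168, world_size=4))
```

If the row-count assumption did not hold, `view` would raise an error in the real code, just as the sketch raises `ValueError`.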