diff --git a/vllm-v0.6.2/vllm/model_executor/layers/linear.py b/vllm-v0.6.2/vllm/model_executor/layers/linear.py
index 2e66428..a53d5e5 100644
--- a/vllm-v0.6.2/vllm/model_executor/layers/linear.py
+++ b/vllm-v0.6.2/vllm/model_executor/layers/linear.py
@@ -146,6 +146,7 @@ class LinearBase(torch.nn.Module):
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        return_bias: If False, return only the output tensor instead of an (output, bias) tuple.
     """
 
     def __init__(
@@ -156,6 +157,7 @@ class LinearBase(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        return_bias: bool = True,
     ):
         super().__init__()
 
@@ -163,6 +165,7 @@ class LinearBase(torch.nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.skip_bias_add = skip_bias_add
+        self.return_bias = return_bias
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -198,13 +201,15 @@ class ReplicatedLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 return_bias: bool = True):
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,
                          params_dtype,
                          quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
         # All the linear layer supports quant method.
         assert self.quant_method is not None
@@ -238,6 +243,9 @@ class ReplicatedLinear(LinearBase):
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
+
+        if not self.return_bias:
+            return output
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -281,9 +289,10 @@ class ColumnParallelLinear(LinearBase):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[List[int]] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 return_bias: bool = True):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+                         quant_config, prefix, return_bias=return_bias)
 
         self.gather_output = gather_output
 
@@ -375,6 +384,9 @@ class ColumnParallelLinear(LinearBase):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.return_bias:
+            return output
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -418,7 +430,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 return_bias: bool = True):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -429,7 +442,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
     def weight_loader(self,
                       param: Parameter,
@@ -653,7 +667,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 return_bias: bool = True):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -686,7 +701,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -980,9 +996,10 @@ class RowParallelLinear(LinearBase):
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 return_bias: bool = True):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+                         quant_config, prefix, return_bias=return_bias)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -1086,8 +1103,10 @@ class RowParallelLinear(LinearBase):
         else:
             output = output_parallel
 
+        if not self.return_bias:
+            return output
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
     def extra_repr(self) -> str:
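
The patch keeps the default (output, output_bias) contract of every linear layer and only lets callers opt into a bare-tensor return. A minimal usage sketch of the opt-in path follows; the layer sizes are hypothetical, and it assumes an environment where ReplicatedLinear can be constructed outside a distributed worker (it needs no tensor-parallel group in its forward pass):

import torch

from vllm.model_executor.layers.linear import ReplicatedLinear

# Default behaviour: forward returns a (output, output_bias) tuple,
# so callers must unpack it. output_bias is None unless the layer was
# built with skip_bias_add=True.
layer = ReplicatedLinear(input_size=16, output_size=32, bias=True)
x = torch.randn(4, 16, dtype=layer.params_dtype)
output, output_bias = layer(x)

# With return_bias=False the same call yields only the output tensor,
# which suits wrappers or traced graphs that expect forward() to
# return a plain tensor.
layer_no_tuple = ReplicatedLinear(input_size=16, output_size=32, bias=True,
                                  return_bias=False)
output = layer_no_tuple(x)

Since return_bias defaults to True, existing call sites that unpack the tuple keep working unchanged; only code that explicitly passes return_bias=False sees the new single-tensor return.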