Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -8,10 +8,7 @@ from weakref import WeakValueDictionary
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
|
||||
__all__ = [
|
||||
@@ -197,7 +194,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
assert param_data.shape == loaded_weight.shape, f"param_data.shape({param_data.shape}) != loaded_weight.shape({loaded_weight.shape})"
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@@ -218,16 +215,24 @@ class RowvLLMParameter(BasevLLMParameter):
|
||||
return self._input_dim
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
shard_size = self.data.shape[self.input_dim]
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, self.tp_rank * shard_size, shard_size
|
||||
)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
k = (loaded_weight.shape[0] if loaded_weight.ndim == 1
|
||||
else loaded_weight.shape[self.input_dim])
|
||||
assert k % self.tp_size == 0, (
|
||||
f"Row dimension({k}) must be divisible by tp_size({self.tp_size})!")
|
||||
shard_size = k // self.tp_size
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(self.input_dim, start_idx, shard_size)
|
||||
|
||||
if loaded_weight.ndim == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert self.data.shape == loaded_weight.shape
|
||||
self.data.copy_(loaded_weight)
|
||||
if self.data.shape == loaded_weight.shape:
|
||||
self.data.copy_(loaded_weight)
|
||||
else:
|
||||
target_slice = self.data.narrow(self.input_dim, 0, shard_size)
|
||||
target_slice.copy_(loaded_weight)
|
||||
|
||||
|
||||
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
|
||||
Reference in New Issue
Block a user