Improve linear.py to load sharded weights & remove the dependency of Parameters from vllm (#2784)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
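
As a rough illustration of what the new constructor arguments enable (not part of the diff; the layer sizes and the explicit rank/size values below are invented):

    from sglang.srt.layers.linear import RowParallelLinear

    # Ranks can now be pinned explicitly instead of always querying the global
    # tensor-parallel group, and pre-sharded checkpoints skip narrow() at load time.
    proj = RowParallelLinear(
        input_size=4096,
        output_size=4096,
        bias=False,
        tp_rank=0,                    # optional override; defaults to the TP group rank
        tp_size=8,                    # optional override; defaults to the TP world size
        use_presharded_weights=True,  # checkpoint already holds this rank's shard
    )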
@@ -66,7 +66,14 @@ class AttentionBackend(ABC):
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
+            return self.forward_extend(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+            )

     def forward_decode(
         self,
@@ -347,6 +347,8 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

+        logits_soft_cap = layer.logit_cap
+
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -359,7 +361,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -368,7 +370,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
            )

            if self.forward_metadata.extend_no_prefix:
@@ -18,14 +18,15 @@ from vllm.distributed import (

 # workaround
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.parameter import (
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )

 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -94,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight


+def load_column_qkv_weight(
+    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
+):
+    if (
+        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+        and self.output_dim == self.packed_dim
+    ):
+        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+            shard_offset=shard_offset, shard_size=shard_size
+        )
+
+    param_data = self.data
+    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(
+        self.output_dim, shard_id * shard_size, shard_size
+    )
+
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
+
+
+def load_column_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, _ColumnvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
+def load_row_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, RowvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""

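
For illustration, a sketch of how the new module-level loaders behave (not part of the diff; the shapes and the throwaway weight_loader below are assumptions):

    import torch

    from sglang.srt.layers.linear import load_row_parallel_weight
    from sglang.srt.layers.parameter import ModelWeightParameter

    # Full weight on disk is (out=8, in=16); with tp_size=4 each rank owns in=4.
    full = torch.arange(8 * 16, dtype=torch.float32).reshape(8, 16)
    param = ModelWeightParameter(
        data=torch.empty(8, 4),
        input_dim=1,
        output_dim=0,
        weight_loader=lambda *args: None,  # unused in this sketch
    )

    # Regular checkpoint: rank 2 narrows out columns 8..12 of the full tensor.
    load_row_parallel_weight(param, full, tp_rank=2)

    # Pre-sharded checkpoint: the file already holds this rank's (8, 4) slice,
    # so the narrow() is skipped and the shard is copied as-is.
    load_row_parallel_weight(
        param, full[:, 8:12], tp_rank=2, use_presharded_weights=True
    )
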
@@ -287,6 +344,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -295,7 +354,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output

         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -336,7 +399,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)

         # Special case for GGUF
@@ -356,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -364,7 +426,9 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -373,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+        load_column_parallel_weight(param, loaded_weight, self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -393,7 +457,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s

@@ -431,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -444,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def weight_loader(
@@ -463,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -494,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 param_data, loaded_weight, 0
             )

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
         return
         current_shard_offset = 0
@@ -522,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return

         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -545,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -572,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     "the same for all partitions."
                 )

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

     def _load_fused_module_from_checkpoint(
@@ -629,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

         assert loaded_shard_id < len(self.output_sizes)

-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )

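
To make the block-quant scale arithmetic above concrete, a worked example with invented numbers:

    # Two fused outputs of 14336 (gate/up), block_n = 128, tp_size = 8,
    # loading shard 1:
    output_sizes, block_n, tp_size, shard_id = [14336, 14336], 128, 8, 1
    shard_offset = ((sum(output_sizes[:shard_id]) + block_n - 1) // block_n) // tp_size
    shard_size = (output_sizes[shard_id] + block_n - 1) // block_n // tp_size
    assert (shard_offset, shard_size) == (14, 14)  # 112 block-scale rows split 8 ways
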
@@ -689,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -697,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -724,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -814,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n

-        param.load_qkv_weight(
-            loaded_weight=loaded_weight,
-            num_heads=self.num_kv_head_replicas,
-            shard_id=loaded_shard_id,
-            shard_offset=shard_offset,
-            shard_size=shard_size,
-        )
+        if isinstance(param, _ColumnvLLMParameter):
+            load_column_qkv_weight(
+                param,
+                loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+                tp_rank=self.tp_rank,
+            )
+        else:
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+            )

     def weight_loader(
         self,
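
The shard_id arithmetic used by these QKV loaders selects which slice of the checkpoint each rank reads; a small worked example with invented values:

    # GQA setup: tp_size = 16, 8 total KV heads, so num_kv_head_replicas = 2
    # and ranks 4 and 5 load the same k/v shard.
    num_kv_head_replicas, tp_rank = 2, 5
    q_shard_id = tp_rank                           # q is unique per rank -> 5
    kv_shard_id = tp_rank // num_kv_head_replicas  # k/v shared by pairs  -> 2
    assert (q_shard_id, kv_shard_id) == (5, 2)
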
@@ -840,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return

         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size

             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

@@ -872,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                 param_data, loaded_weight, 0
             )

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
         return
         shard_offsets = [
@@ -934,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return

-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]

         # If output dim is defined, use the default loading process.
@@ -984,9 +1075,9 @@ class QKVParallelLinear(ColumnParallelLinear):

             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size

             # bitsandbytes loads the weights of the specific portion
@@ -1014,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                 "for all partitions."
             )

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

@@ -1055,6 +1148,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1064,10 +1160,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results

         # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights

         self.quant_method.create_weights(
             layer=self,
@@ -1101,8 +1201,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

@@ -1116,15 +1214,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
@@ -1132,7 +1234,9 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1143,17 +1247,21 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        load_row_parallel_weight(
+            param,
+            loaded_weight,
+            self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )

     def forward(self, input_):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         assert self.quant_method is not None

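
A sketch of the non-parallel input path in forward (shapes invented; torch.chunk stands in for split_tensor_along_last_dim):

    import torch

    x = torch.randn(2, 4096)                 # activation entering the layer
    splits = torch.chunk(x, 4, dim=-1)       # tp_size = 4 partitions
    input_parallel = splits[1].contiguous()  # rank 1 keeps its (2, 1024) slice
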
@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +397,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality

python/sglang/srt/layers/parameter.py (new file, 431 lines)
@@ -0,0 +1,431 @@
"""
Adapted from vLLM (0.6.4.post1).
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
"""

import logging
from fractions import Fraction
from typing import Callable, Optional, Union

import torch
from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BasevLLMParameter",
    "PackedvLLMParameter",
    "PerTensorScaleParameter",
    "ModelWeightParameter",
    "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter",
    "PackedColumnParameter",
    "RowvLLMParameter",
]

logger = logging.getLogger(__name__)


class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):

        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """

        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)


class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(
            self.output_dim, tp_rank * shard_size, shard_size
        )
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):

        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        use_presharded_weights = kwargs.get("use_presharded_weights")
        if (
            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
            and self.packed_dim == self.output_dim
        ):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size
            )

        param_data = self.data

        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
        if not use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                self.output_dim, tp_rank * shard_size, shard_size
            )
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):

        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        if (
            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
            and self.output_dim == self.packed_dim
        ):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size
            )

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.narrow(
            self.output_dim, shard_id * shard_size, shard_size
        )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
        use_presharded_weights = kwargs.get("use_presharded_weights")
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        if not use_presharded_weights:
            loaded_weight = loaded_weight.narrow(
                self.input_dim, tp_rank * shard_size, shard_size
            )

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """

    pass


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """

    pass


class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
    """

    pass


class PerTensorScaleParameter(BasevLLMParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scalers to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding needed
    # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(
        self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
    ):
        """
        Slice the parameter data based on the shard id for
        loading.
        """

        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(
        self,
        packed_factor: Union[int, Fraction],
        packed_dim: int,
        marlin_tile_size: Optional[int] = None,
        **kwargs
    ):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )


class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(
        self,
        packed_factor: Union[int, Fraction],
        packed_dim: int,
        marlin_tile_size: Optional[int] = None,
        **kwargs
    ):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )


def permute_param_layout_(
    param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs
) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2, (
            "permute_param_layout_ only supports 2D parameters when either "
            "input_dim or output_dim is not set"
        )

    # if one of the dimensions is not set, set it to the opposite of the other
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None, "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None, "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert (
            hasattr(param, "packed_dim")
            and param.packed_dim == perm[kwargs["packed_dim"]]
        ), "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_packing(
    shard_size, shard_offset, packed_factor, marlin_tile_size
):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size,
        )
    return shard_size, shard_offset
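
A concrete illustration of the packing adjustment at the end of this file (numbers invented):

    # GPTQ-style packing: 8 int4 values per int32 element (packed_factor = 8),
    # no marlin tiling. A logical shard of 1024 rows at offset 2048 maps to:
    size, offset = _adjust_shard_indexes_for_packing(
        shard_size=1024, shard_offset=2048, packed_factor=8, marlin_tile_size=None
    )
    assert (size, offset) == (128, 256)
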
@@ -25,9 +25,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,

@@ -12,8 +12,8 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parameter import BasevLLMParameter

+from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,

@@ -99,7 +99,7 @@ class Session:

         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]

             input_ids = (

@@ -106,6 +106,9 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+

 class CaptureHiddenMode(IntEnum):
     NULL = auto()

@@ -205,7 +205,7 @@ class ModelRunner:
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # TODO(liangan1): Just use gloo to bypass the initialization fail
             # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
@@ -634,7 +634,6 @@ class ModelRunner:
         )

     def init_double_sparsity_channel_config(self, selected_channel):
-
         selected_channel = "." + selected_channel + "_proj"
         self.sorted_channels = []
         # load channel config

@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")

@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id

         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(config, i, quant_config=quant_config)
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)

         # Monkey patch _prepare_weights to load pre-sharded weights
         if (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
         else:
             self.use_presharded_weights = False

+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        use_presharded_weights: Optional[bool] = None,
     ):
-        if use_presharded_weights is None:
-            use_presharded_weights = self.use_presharded_weights
         num_experts = self.config.num_local_experts

         stacked_params_mapping = [
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)

-                if use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 load_weight_wrapper(
                     name,
                     loaded_weight,
                     name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:

|
||||
|
||||
# Send a warmup request
|
||||
t = threading.Thread(
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
tokenizer_manager.image_token_id,
|
||||
),
|
||||
)
|
||||
t.start()
|
||||
|
||||
@@ -614,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
if server_args.api_key:
|
||||
|
||||
@@ -14,7 +14,7 @@ from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
|
||||