Improve linear.py to load sharded weights & remove the dependency of Parameters from vllm (#2784)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
5
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
5
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
@@ -10,7 +10,10 @@ import triton.language as tl
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe,
|
||||||
|
get_config_file_name,
|
||||||
|
)
|
||||||
|
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,14 @@ class AttentionBackend(ABC):
|
|||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
|
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
|
||||||
else:
|
else:
|
||||||
return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
|
return self.forward_extend(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -347,6 +347,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
else forward_batch.encoder_out_cache_loc
|
else forward_batch.encoder_out_cache_loc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logits_soft_cap = layer.logit_cap
|
||||||
|
|
||||||
if not self.forward_metadata.use_ragged:
|
if not self.forward_metadata.use_ragged:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -359,7 +361,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
causal=not layer.is_cross_attention,
|
causal=not layer.is_cross_attention,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
window_left=layer.sliding_window_size,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
@@ -368,7 +370,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
|
v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.forward_metadata.extend_no_prefix:
|
if self.forward_metadata.extend_no_prefix:
|
||||||
|
|||||||
@@ -18,14 +18,15 @@ from vllm.distributed import (
|
|||||||
|
|
||||||
# workaround
|
# workaround
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
from vllm.model_executor.parameter import (
|
|
||||||
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
PackedColumnParameter,
|
PackedColumnParameter,
|
||||||
PackedvLLMParameter,
|
PackedvLLMParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
RowvLLMParameter,
|
RowvLLMParameter,
|
||||||
|
_ColumnvLLMParameter,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -94,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
|||||||
return param[shard_id], loaded_weight
|
return param[shard_id], loaded_weight
|
||||||
|
|
||||||
|
|
||||||
|
def load_column_qkv_weight(
    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
):
    """Copy one q/k/v checkpoint slice into a fused QKV parameter.

    This is a free function: ``self`` is the parameter object being filled,
    and the tensor-parallel rank is supplied by the caller rather than
    looked up inside the function.

    Args:
        self: Parameter to fill. Must expose ``data`` and ``output_dim``;
            packed parameter types must also expose ``packed_dim`` and
            ``adjust_shard_indexes_for_packing``.
        loaded_weight: Checkpoint tensor for one of the q/k/v projections.
        num_heads: Divisor applied to ``tp_rank`` for every shard other
            than ``"q"``.
        shard_id: ``"q"`` reads source slice number ``tp_rank``; any other
            value reads slice number ``tp_rank // num_heads``.
        shard_offset: Start of the destination slice along ``output_dim``.
        shard_size: Length of the slice along ``output_dim``.
        tp_rank: Tensor-parallel rank used to pick the source slice.

    Raises:
        AssertionError: If the destination slice and the selected source
            slice differ in shape.
    """
    if (
        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
        and self.output_dim == self.packed_dim
    ):
        # The parameter itself rewrites offset/size when the output
        # dimension is also the packed dimension.
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size
        )

    # "q" is indexed by the rank directly; other shards by rank // num_heads.
    source_index = tp_rank if shard_id == "q" else tp_rank // num_heads

    param_data = self.data.narrow(self.output_dim, shard_offset, shard_size)
    loaded_weight = loaded_weight.narrow(
        self.output_dim, source_index * shard_size, shard_size
    )

    # Include both shapes so a mismatch is diagnosable from the traceback.
    assert (
        param_data.shape == loaded_weight.shape
    ), f"{param_data.shape=}, {loaded_weight.shape=}"
    param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def load_column_parallel_weight(
    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
):
    """Copy a column-parallel checkpoint tensor into parameter ``self``.

    This is a free function: ``self`` is the parameter object being filled.

    Args:
        self: Parameter to fill. For ``_ColumnvLLMParameter`` instances the
            tensor is sliced along ``self.output_dim``; any other parameter
            receives ``loaded_weight`` unchanged.
        loaded_weight: Tensor read from the checkpoint.
        tp_rank: Tensor-parallel rank; selects which slice of
            ``loaded_weight`` is kept.
        use_presharded_weights: When true, ``loaded_weight`` is copied
            without slicing.

    Raises:
        AssertionError: For ``_ColumnvLLMParameter`` instances, if the
            tensor to copy does not match ``self.data`` in shape.
    """
    if isinstance(self, _ColumnvLLMParameter):
        if not use_presharded_weights:
            # Slice length comes from the already-allocated local parameter.
            shard_size = self.data.shape[self.output_dim]
            loaded_weight = loaded_weight.narrow(
                self.output_dim, tp_rank * shard_size, shard_size
            )
        # Include both shapes so a mismatch is diagnosable from the traceback.
        assert (
            self.data.shape == loaded_weight.shape
        ), f"{self.data.shape=}, {loaded_weight.shape=}"

    # Both branches of the original ended with this same copy.
    self.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def load_row_parallel_weight(
    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
):
    """Copy a row-parallel checkpoint tensor into parameter ``self``.

    This is a free function: ``self`` is the parameter object being filled.

    Args:
        self: Parameter to fill. For ``RowvLLMParameter`` instances the
            tensor is sliced along ``self.input_dim``; any other parameter
            receives ``loaded_weight`` unchanged.
        loaded_weight: Tensor read from the checkpoint.
        tp_rank: Tensor-parallel rank; selects which slice of
            ``loaded_weight`` is kept.
        use_presharded_weights: When true, ``loaded_weight`` is copied
            without slicing.

    Raises:
        AssertionError: For ``RowvLLMParameter`` instances, if the tensor
            to copy does not match ``self.data`` in shape.
    """
    if isinstance(self, RowvLLMParameter):
        if not use_presharded_weights:
            # Slice length comes from the already-allocated local parameter.
            shard_size = self.data.shape[self.input_dim]
            loaded_weight = loaded_weight.narrow(
                self.input_dim, tp_rank * shard_size, shard_size
            )

        # A 0-dim tensor is reshaped to one element before the shape check.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Include both shapes so a mismatch is diagnosable from the traceback.
        assert (
            self.data.shape == loaded_weight.shape
        ), f"{self.data.shape=}, {loaded_weight.shape=}"

    # Both branches of the original ended with this same copy.
    self.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
class LinearMethodBase(QuantizeMethodBase):
|
class LinearMethodBase(QuantizeMethodBase):
|
||||||
"""Base class for different (maybe quantized) linear methods."""
|
"""Base class for different (maybe quantized) linear methods."""
|
||||||
|
|
||||||
@@ -287,6 +344,8 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
output_sizes: Optional[List[int]] = None,
|
output_sizes: Optional[List[int]] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
||||||
@@ -295,7 +354,11 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
if tp_rank is None:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
if tp_size is None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank, self.tp_size = tp_rank, tp_size
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
self.output_size_per_partition = divide(self.output_size, tp_size)
|
self.output_size_per_partition = divide(self.output_size, tp_size)
|
||||||
self.output_partition_sizes = [self.output_size_per_partition]
|
self.output_partition_sizes = [self.output_size_per_partition]
|
||||||
@@ -336,7 +399,6 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
|
||||||
# Special case for GGUF
|
# Special case for GGUF
|
||||||
@@ -356,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
# no need to narrow here
|
# no need to narrow here
|
||||||
if output_dim is not None and not use_bitsandbytes_4bit:
|
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||||
shard_size = param_data.shape[output_dim]
|
shard_size = param_data.shape[output_dim]
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
# Special case for loading scales off disk, which often do not
|
||||||
@@ -364,7 +426,9 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
@@ -373,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
assert loaded_weight.numel() == 1
|
assert loaded_weight.numel() == 1
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
load_column_parallel_weight(param, loaded_weight, self.tp_rank)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
@@ -393,7 +457,7 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
s = f"in_features={self.input_size}"
|
s = f"in_features={self.input_size}"
|
||||||
s += f", output_features={self.output_size_per_partition}"
|
s += f", output_features={self.output_size_per_partition}"
|
||||||
s += f", bias={self.bias is not None}"
|
s += f", bias={self.bias is not None}"
|
||||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
s += f", tp_size={self.tp_size}"
|
||||||
s += f", gather_output={self.gather_output}"
|
s += f", gather_output={self.gather_output}"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@@ -431,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
self.output_sizes = output_sizes
|
self.output_sizes = output_sizes
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
if tp_rank is None:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
if tp_size is None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank, self.tp_size = tp_rank, tp_size
|
||||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||||
|
self.use_presharded_weights = use_presharded_weights
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
output_size=sum(output_sizes),
|
output_size=sum(output_sizes),
|
||||||
@@ -444,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
tp_size=tp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def weight_loader(
|
def weight_loader(
|
||||||
@@ -463,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if is_gguf_weight:
|
if is_gguf_weight:
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
shard_size = loaded_weight.size(output_dim) // tp_size
|
shard_size = loaded_weight.size(output_dim) // self.tp_size
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
|
|
||||||
@@ -494,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
param_data, loaded_weight, 0
|
param_data, loaded_weight, 0
|
||||||
)
|
)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
return
|
return
|
||||||
current_shard_offset = 0
|
current_shard_offset = 0
|
||||||
@@ -522,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert loaded_shard_id < len(self.output_sizes)
|
assert loaded_shard_id < len(self.output_sizes)
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if output_dim is not None:
|
if output_dim is not None:
|
||||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||||
# Special case for quantization.
|
# Special case for quantization.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
# for the packing.
|
# for the packing.
|
||||||
@@ -545,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
||||||
|
|
||||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
# bitsandbytes loads the weights of the specific portion
|
# bitsandbytes loads the weights of the specific portion
|
||||||
# no need to narrow here
|
# no need to narrow here
|
||||||
if not use_bitsandbytes_4bit:
|
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
# Special case for AQLM codebooks.
|
# Special case for AQLM codebooks.
|
||||||
elif is_metadata:
|
elif is_metadata:
|
||||||
@@ -572,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
"the same for all partitions."
|
"the same for all partitions."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
def _load_fused_module_from_checkpoint(
|
def _load_fused_module_from_checkpoint(
|
||||||
@@ -629,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
|
|
||||||
assert loaded_shard_id < len(self.output_sizes)
|
assert loaded_shard_id < len(self.output_sizes)
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
|
|
||||||
if isinstance(param, BlockQuantScaleParameter):
|
if isinstance(param, BlockQuantScaleParameter):
|
||||||
weight_block_size = self.quant_method.quant_config.weight_block_size
|
weight_block_size = self.quant_method.quant_config.weight_block_size
|
||||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||||
shard_offset = (
|
shard_offset = (
|
||||||
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
|
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
|
||||||
) // tp_size
|
) // self.tp_size
|
||||||
shard_size = (
|
shard_size = (
|
||||||
(self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
|
(self.output_sizes[loaded_shard_id] + block_n - 1)
|
||||||
|
// block_n
|
||||||
|
// self.tp_size
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||||
|
|
||||||
param.load_merged_column_weight(
|
param.load_merged_column_weight(
|
||||||
loaded_weight=loaded_weight,
|
loaded_weight=loaded_weight,
|
||||||
shard_id=loaded_shard_id,
|
shard_id=loaded_shard_id,
|
||||||
shard_offset=shard_offset,
|
shard_offset=shard_offset,
|
||||||
shard_size=shard_size,
|
shard_size=shard_size,
|
||||||
|
use_presharded_weights=self.use_presharded_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -689,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -697,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
total_num_kv_heads = total_num_heads
|
total_num_kv_heads = total_num_heads
|
||||||
self.total_num_kv_heads = total_num_kv_heads
|
self.total_num_kv_heads = total_num_kv_heads
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
if tp_rank is None:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
if tp_size is None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank, self.tp_size = tp_rank, tp_size
|
||||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||||
if tp_size >= self.total_num_kv_heads:
|
if tp_size >= self.total_num_kv_heads:
|
||||||
self.num_kv_heads = 1
|
self.num_kv_heads = 1
|
||||||
@@ -724,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
tp_size=tp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||||
@@ -814,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||||
shard_size = (shard_size + block_n - 1) // block_n
|
shard_size = (shard_size + block_n - 1) // block_n
|
||||||
|
|
||||||
param.load_qkv_weight(
|
if isinstance(param, _ColumnvLLMParameter):
|
||||||
loaded_weight=loaded_weight,
|
load_column_qkv_weight(
|
||||||
num_heads=self.num_kv_head_replicas,
|
param,
|
||||||
shard_id=loaded_shard_id,
|
loaded_weight,
|
||||||
shard_offset=shard_offset,
|
num_heads=self.num_kv_head_replicas,
|
||||||
shard_size=shard_size,
|
shard_id=loaded_shard_id,
|
||||||
)
|
shard_offset=shard_offset,
|
||||||
|
shard_size=shard_size,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
param.load_qkv_weight(
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
num_heads=self.num_kv_head_replicas,
|
||||||
|
shard_id=loaded_shard_id,
|
||||||
|
shard_offset=shard_offset,
|
||||||
|
shard_size=shard_size,
|
||||||
|
)
|
||||||
|
|
||||||
def weight_loader(
|
def weight_loader(
|
||||||
self,
|
self,
|
||||||
@@ -840,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if is_gguf_weight:
|
if is_gguf_weight:
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
output_dim = getattr(param, "output_dim", None)
|
||||||
shard_size = loaded_weight.size(output_dim) // tp_size
|
shard_size = loaded_weight.size(output_dim) // self.tp_size
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||||
|
|
||||||
@@ -872,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
param_data, loaded_weight, 0
|
param_data, loaded_weight, 0
|
||||||
)
|
)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
return
|
return
|
||||||
shard_offsets = [
|
shard_offsets = [
|
||||||
@@ -934,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
assert loaded_shard_id in ["q", "k", "v"]
|
assert loaded_shard_id in ["q", "k", "v"]
|
||||||
|
|
||||||
# If output dim is defined, use the default loading process.
|
# If output dim is defined, use the default loading process.
|
||||||
@@ -984,9 +1075,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
|
|
||||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||||
if loaded_shard_id == "q":
|
if loaded_shard_id == "q":
|
||||||
shard_id = tp_rank
|
shard_id = self.tp_rank
|
||||||
else:
|
else:
|
||||||
shard_id = tp_rank // self.num_kv_head_replicas
|
shard_id = self.tp_rank // self.num_kv_head_replicas
|
||||||
start_idx = shard_id * shard_size
|
start_idx = shard_id * shard_size
|
||||||
|
|
||||||
# bitsandbytes loads the weights of the specific portion
|
# bitsandbytes loads the weights of the specific portion
|
||||||
@@ -1014,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
"for all partitions."
|
"for all partitions."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
@@ -1055,6 +1148,9 @@ class RowParallelLinear(LinearBase):
|
|||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
||||||
@@ -1064,10 +1160,14 @@ class RowParallelLinear(LinearBase):
|
|||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
if tp_rank is None:
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
if tp_size is None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank, self.tp_size = tp_rank, tp_size
|
||||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
self.use_presharded_weights = use_presharded_weights
|
||||||
|
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
@@ -1101,8 +1201,6 @@ class RowParallelLinear(LinearBase):
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
input_dim = getattr(param, "input_dim", None)
|
input_dim = getattr(param, "input_dim", None)
|
||||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
|
|
||||||
@@ -1116,15 +1214,19 @@ class RowParallelLinear(LinearBase):
|
|||||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||||
weight_shape = list(loaded_weight.shape)
|
weight_shape = list(loaded_weight.shape)
|
||||||
if input_dim:
|
if input_dim:
|
||||||
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
|
weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
|
||||||
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
||||||
|
|
||||||
param_data = param.data
|
param_data = param.data
|
||||||
# bitsandbytes loads the weights of the specific portion
|
# bitsandbytes loads the weights of the specific portion
|
||||||
# no need to narrow here
|
# no need to narrow here
|
||||||
if input_dim is not None and not use_bitsandbytes_4bit:
|
if (
|
||||||
|
input_dim is not None
|
||||||
|
and not use_bitsandbytes_4bit
|
||||||
|
and not self.use_presharded_weights
|
||||||
|
):
|
||||||
shard_size = param_data.shape[input_dim]
|
shard_size = param_data.shape[input_dim]
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
# Special case for loading scales off disk, which often do not
|
||||||
@@ -1132,7 +1234,9 @@ class RowParallelLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert (
|
||||||
|
param_data.shape == loaded_weight.shape
|
||||||
|
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
|
def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
|
||||||
@@ -1143,17 +1247,21 @@ class RowParallelLinear(LinearBase):
|
|||||||
assert loaded_weight.numel() == 1
|
assert loaded_weight.numel() == 1
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
load_row_parallel_weight(
|
||||||
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
self.tp_rank,
|
||||||
|
use_presharded_weights=self.use_presharded_weights,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
else:
|
else:
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
splitted_input = split_tensor_along_last_dim(
|
splitted_input = split_tensor_along_last_dim(
|
||||||
input_, num_partitions=self.tp_size
|
input_, num_partitions=self.tp_size
|
||||||
)
|
)
|
||||||
input_parallel = splitted_input[tp_rank].contiguous()
|
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
weight_loader=self.weight_loader,
|
weight_loader=self.weight_loader,
|
||||||
)
|
)
|
||||||
|
self.use_presharded_weights = use_presharded_weights
|
||||||
|
|
||||||
def _load_per_tensor_weight_scale(
|
def _load_per_tensor_weight_scale(
|
||||||
self,
|
self,
|
||||||
@@ -395,10 +397,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
weight_name: str,
|
weight_name: str,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
use_presharded_weights: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.use_presharded_weights = use_presharded_weights
|
|
||||||
|
|
||||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||||
# against known CompressionFormat enum values that have this quality
|
# against known CompressionFormat enum values that have this quality
|
||||||
|
|||||||
431
python/sglang/srt/layers/parameter.py
Normal file
431
python/sglang/srt/layers/parameter.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
"""
|
||||||
|
Adapted from vLLM (0.6.4.post1).
|
||||||
|
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BasevLLMParameter",
|
||||||
|
"PackedvLLMParameter",
|
||||||
|
"PerTensorScaleParameter",
|
||||||
|
"ModelWeightParameter",
|
||||||
|
"ChannelQuantScaleParameter",
|
||||||
|
"GroupQuantScaleParameter",
|
||||||
|
"PackedColumnParameter",
|
||||||
|
"RowvLLMParameter",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends torch.nn.Parameter by
    carrying a weight loader callable next to the tensor data.

    Every ``load_*`` method defined here copies the loaded weight into the
    parameter as-is after checking that the shapes match. Subclasses
    override them to add sharding. All ``load_*`` methods accept and ignore
    extra keyword arguments (e.g. ``use_presharded_weights``) so a caller can
    use one call signature for every parameter type.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Extra keyword arguments (weight_loader, output_dim, ...) are only
        # meaningful to __init__; the tensor itself never requires grad.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        """The weight loader callable given at construction time."""
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into this parameter; shapes must match."""
        assert self.data.shape == loaded_weight.shape, (
            f"shape mismatch: parameter {tuple(self.data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}"
        )
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the weight unsharded; extra keyword arguments are ignored."""
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the weight unsharded; extra keyword arguments are ignored."""
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the weight unsharded; extra keyword arguments are ignored."""
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the weight unsharded; extra keyword arguments are ignored."""
        self._assert_and_load(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private base class providing the column-parallel loaders
    (load_column_parallel_weight, load_merged_column_weight,
    load_qkv_weight). Each loader narrows the loaded tensor along
    ``output_dim`` to this tensor-parallel rank's shard before copying it
    into the parameter. Requires an output dimension to be defined.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        rank = get_tensor_model_parallel_rank()
        # The parameter already has the per-rank size along output_dim.
        width = self.data.shape[self.output_dim]
        shard = loaded_weight.narrow(self.output_dim, rank * width, width)
        assert self.data.shape == shard.shape
        self.data.copy_(shard)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        offset = kwargs.get("shard_offset")
        size = kwargs.get("shard_size")
        presharded = kwargs.get("use_presharded_weights")

        # Packed parameters express offsets/sizes in packed units when the
        # packing runs along the output dimension.
        is_packed = isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
        if is_packed and self.packed_dim == self.output_dim:
            size, offset = self.adjust_shard_indexes_for_packing(
                shard_offset=offset, shard_size=size
            )

        destination = self.data
        rank = get_tensor_model_parallel_rank()
        destination = destination.narrow(self.output_dim, offset, size)
        if not presharded:
            # Only slice the checkpoint tensor when it is not already sharded.
            loaded_weight = loaded_weight.narrow(
                self.output_dim, rank * size, size
            )
        assert destination.shape == loaded_weight.shape
        destination.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        offset = kwargs.get("shard_offset")
        size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        is_packed = isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
        if is_packed and self.output_dim == self.packed_dim:
            size, offset = self.adjust_shard_indexes_for_packing(
                shard_offset=offset, shard_size=size
            )

        destination = self.data
        rank = get_tensor_model_parallel_rank()
        # "q" shards are indexed by rank; any other shard id is indexed by
        # rank // num_heads.
        if shard_id == "q":
            source_index = rank
        else:
            source_index = rank // num_heads
        destination = destination.narrow(self.output_dim, offset, size)
        loaded_weight = loaded_weight.narrow(
            self.output_dim, source_index * size, size
        )

        assert destination.shape == loaded_weight.shape
        destination.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class providing the row-parallel loader
    (load_row_parallel_weight): the loaded tensor is narrowed along
    ``input_dim`` to this tensor-parallel rank's shard, unless the caller
    states the weights are already sharded. Requires an input_dim to be
    defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
        presharded = kwargs.get("use_presharded_weights")
        rank = get_tensor_model_parallel_rank()
        # The parameter already has the per-rank size along input_dim.
        width = self.data.shape[self.input_dim]
        if not presharded:
            loaded_weight = loaded_weight.narrow(
                self.input_dim, rank * width, width
            )

        # Promote a 0-dim tensor to shape (1,) before the shape check.
        if loaded_weight.dim() == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Linear layer weight parameter. Combines the column-parallel and the
    row-parallel loaders by inheriting from both parameter classes and adds
    nothing of its own.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Weight-scale parameter for weights with grouped quantization. Combines
    the column-parallel and the row-parallel loaders by inheriting from both
    parameter classes and adds nothing of its own.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Weight-scale parameter for weights with channel-wise quantization.
    Behaves exactly like _ColumnvLLMParameter; it exists only as a distinct
    type.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class PerTensorScaleParameter(BasevLLMParameter):
    """
    Parameter class for scales where the number of scales equals the number
    of logical matrices in a fused linear layer (e.g. for QKV, there are 3
    scales loaded from disk). Relevant to weights with per-tensor
    quantization. Maps each loaded scale to its shard slot during weight
    loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        # Slot index of each named QKV shard inside the scale tensor.
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """Return ``shard_id`` as an integer slot index."""
        if not isinstance(shard_id, int):
            # Anything that is not an int must be a known QKV shard name.
            assert isinstance(shard_id, str)
            assert shard_id in self.qkv_idxs
            shard_id = self.qkv_idxs[shard_id]
        return shard_id

    # Row parallel layers need no sharding: the scale is loaded as is.
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    # Delegates to the base row loader, which is the same plain copy.
    def load_column_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(
        self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
    ):
        """
        Copy ``loaded_weight`` into the slot of the parameter data selected
        by ``shard_id``.
        """
        scales = self.data
        slot = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if loaded_weight.dim() != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        target = scales[slot]
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Column-parallel-only parameter for model parameters that are packed on
    disk. Stores the packing factor, the packed dimension and an optional
    marlin tile size, and uses them to convert shard sizes/offsets into
    packed units. See PackedvLLMParameter for more details on the packed
    properties.
    """

    def __init__(
        self,
        packed_factor: Union[int, Fraction],
        packed_dim: int,
        marlin_tile_size: Optional[int] = None,
        **kwargs
    ):
        # Store the packing metadata before the parent initializers run.
        self._marlin_tile_size = marlin_tile_size
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Return ``(shard_size, shard_offset)`` adjusted for packing."""
        adjusted = _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )
        return adjusted
|
||||||
|
|
||||||
|
|
||||||
|
class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends ModelWeightParameter with the packed factor, the packed
    dimension and, optionally, the marlin tile size used by marlin kernels.
    These are used to adjust shard_size and shard_offset when loading
    weights of fused linear layers.
    """

    def __init__(
        self,
        packed_factor: Union[int, Fraction],
        packed_dim: int,
        marlin_tile_size: Optional[int] = None,
        **kwargs
    ):
        # Store the packing metadata before the parent initializers run.
        self._marlin_tile_size = marlin_tile_size
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Return ``(shard_size, shard_offset)`` adjusted for packing."""
        adjusted = _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
        )
        return adjusted
|
||||||
|
|
||||||
|
|
||||||
|
def permute_param_layout_(
    param: "BasevLLMParameter", input_dim: int, output_dim: int, **kwargs
) -> "BasevLLMParameter":
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
    {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
    permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)

    :param param: parameter to permute in place; its ``data`` is replaced by
        the permuted view and its ``_input_dim`` / ``_output_dim`` /
        ``_packed_dim`` attributes are updated when present
    :param input_dim: index the input dimension must end up at
    :param output_dim: index the output dimension must end up at
    :param kwargs: optionally ``packed_dim``, the index the packed dimension
        is expected at after the permutation (repacking is not supported)

    :returns: the same ``param`` object
    """
    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2, (
            "permute_param_layout_ only supports 2D parameters when either "
            "input_dim or output_dim is not set"
        )

    # if one of the dimensions is not set, set it to the opposite of the other
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None, "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None, "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
    ]
    # Insert in ascending order of target index: inserting the larger target
    # first would let the second insert shift it one slot to the right
    # whenever the parameter has more than two dimensions.
    placements = sorted(
        [(input_dim, curr_input_dim), (output_dim, curr_output_dim)],
        key=lambda placement: placement[0],
    )
    for target_dim, source_dim in placements:
        perm.insert(target_dim, source_dim)

    if "packed_dim" in kwargs:
        assert (
            hasattr(param, "packed_dim")
            and param.packed_dim == perm[kwargs["packed_dim"]]
        ), "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
|
||||||
|
|
||||||
|
|
||||||
|
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size):
|
||||||
|
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||||
|
|
||||||
|
|
||||||
|
def _adjust_shard_indexes_for_packing(
|
||||||
|
shard_size, shard_offset, packed_factor, marlin_tile_size
|
||||||
|
):
|
||||||
|
shard_size = shard_size // packed_factor
|
||||||
|
shard_offset = shard_offset // packed_factor
|
||||||
|
if marlin_tile_size is not None:
|
||||||
|
return _adjust_shard_indexes_for_marlin(
|
||||||
|
shard_size=shard_size,
|
||||||
|
shard_offset=shard_offset,
|
||||||
|
marlin_tile_size=marlin_tile_size,
|
||||||
|
)
|
||||||
|
return shard_size, shard_offset
|
||||||
@@ -25,9 +25,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from vllm.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import BasevLLMParameter
|
|
||||||
|
|
||||||
|
from sglang.srt.layers.parameter import BasevLLMParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class Session:
|
|||||||
|
|
||||||
if last_req is not None:
|
if last_req is not None:
|
||||||
# trim bos token if it is an append
|
# trim bos token if it is an append
|
||||||
if req.input_ids[0] == tokenizer.bos_token_id:
|
if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
|
||||||
req.input_ids = req.input_ids[1:]
|
req.input_ids = req.input_ids[1:]
|
||||||
|
|
||||||
input_ids = (
|
input_ids = (
|
||||||
|
|||||||
@@ -106,6 +106,9 @@ class ForwardMode(IntEnum):
|
|||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
|
def is_decode_or_idle(self):
|
||||||
|
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
|
||||||
|
|
||||||
|
|
||||||
class CaptureHiddenMode(IntEnum):
|
class CaptureHiddenMode(IntEnum):
|
||||||
NULL = auto()
|
NULL = auto()
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class ModelRunner:
|
|||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
elif self.device == "xpu":
|
elif self.device == "xpu":
|
||||||
# TODO(liangan1):Just use gloo to bypass the initilization fail
|
# TODO(liangan1): Just use gloo to bypass the initilization fail
|
||||||
# Need to use xccl for xpu backend in the future
|
# Need to use xccl for xpu backend in the future
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
elif self.device == "hpu":
|
elif self.device == "hpu":
|
||||||
@@ -634,7 +634,6 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_double_sparsity_channel_config(self, selected_channel):
|
def init_double_sparsity_channel_config(self, selected_channel):
|
||||||
|
|
||||||
selected_channel = "." + selected_channel + "_proj"
|
selected_channel = "." + selected_channel + "_proj"
|
||||||
self.sorted_channels = []
|
self.sorted_channels = []
|
||||||
# load channel config
|
# load channel config
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.gate_up_proj",
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
reduce_results=reduce_results,
|
reduce_results=reduce_results,
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
)
|
)
|
||||||
self.act_fn = GeluAndMul(approximate="tanh")
|
self.act_fn = GeluAndMul(approximate="tanh")
|
||||||
|
|
||||||
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
|
|||||||
renormalize=False,
|
renormalize=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
|
|||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
reduce_results: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
|
|||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
reduce_results=reduce_results,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
self.self_attn = Grok1Attention(
|
self.self_attn = Grok1Attention(
|
||||||
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
),
|
),
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
)
|
)
|
||||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
use_presharded_weights: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Grok1DecoderLayer(config, i, quant_config=quant_config)
|
Grok1DecoderLayer(
|
||||||
|
config,
|
||||||
|
i,
|
||||||
|
quant_config=quant_config,
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
|
)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Grok1Model(config, quant_config=quant_config)
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
|
||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
|
||||||
if (
|
if (
|
||||||
self.config.num_local_experts > 0
|
self.config.num_local_experts > 0
|
||||||
and get_tensor_model_parallel_world_size() > 1
|
and get_tensor_model_parallel_world_size() > 1
|
||||||
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.use_presharded_weights = False
|
self.use_presharded_weights = False
|
||||||
|
|
||||||
|
self.model = Grok1Model(
|
||||||
|
config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
use_presharded_weights=self.use_presharded_weights,
|
||||||
|
)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
def load_weights(
|
def load_weights(
|
||||||
self,
|
self,
|
||||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||||
use_presharded_weights: Optional[bool] = None,
|
|
||||||
):
|
):
|
||||||
if use_presharded_weights is None:
|
|
||||||
use_presharded_weights = self.use_presharded_weights
|
|
||||||
num_experts = self.config.num_local_experts
|
num_experts = self.config.num_local_experts
|
||||||
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
if use_presharded_weights:
|
|
||||||
extra_kwargs = {
|
|
||||||
"use_presharded_weights": use_presharded_weights
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
extra_kwargs = {}
|
|
||||||
|
|
||||||
load_weight_wrapper(
|
load_weight_wrapper(
|
||||||
name,
|
name,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
name,
|
name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
**extra_kwargs,
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -544,7 +544,12 @@ def launch_server(
|
|||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
target=_wait_and_warmup,
|
||||||
|
args=(
|
||||||
|
server_args,
|
||||||
|
pipe_finish_writer,
|
||||||
|
tokenizer_manager.image_token_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
@@ -614,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_warmup(server_args, pipe_finish_writer):
|
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||||
headers = {}
|
headers = {}
|
||||||
url = server_args.url()
|
url = server_args.url()
|
||||||
if server_args.api_key:
|
if server_args.api_key:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
|||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ nvidia-smi
|
|||||||
kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
||||||
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
||||||
kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
||||||
|
kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
|
||||||
|
|
||||||
# Clean all GPU processes if any argument is provided
|
# Clean all GPU processes if any argument is provided
|
||||||
if [ $# -gt 0 ]; then
|
if [ $# -gt 0 ]; then
|
||||||
|
|||||||
Reference in New Issue
Block a user