From 1dce6c480fac4aa20d40f133d220c5be4605e093 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Fri, 4 Jul 2025 00:51:38 +0800 Subject: [PATCH] [CPU] support the case where num_attention_heads or intermediate_size is not divisible by the TP size (#6771) --- python/sglang/srt/configs/update_config.py | 119 ++++++++++++++++++ python/sglang/srt/layers/linear.py | 91 ++++++++++++-- .../srt/layers/moe/fused_moe_triton/layer.py | 42 +++++-- python/sglang/srt/layers/parameter.py | 74 +++++++++-- .../srt/layers/vocab_parallel_embedding.py | 10 +- python/sglang/srt/managers/scheduler.py | 12 +- .../sglang/srt/model_executor/model_runner.py | 7 +- .../sglang/srt/model_loader/weight_utils.py | 54 ++++++++ python/sglang/srt/models/mllama4.py | 20 +-- python/sglang/srt/models/qwen2.py | 8 +- python/sglang/srt/utils.py | 2 + 11 files changed, 399 insertions(+), 40 deletions(-) create mode 100644 python/sglang/srt/configs/update_config.py diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py new file mode 100644 index 000000000..f9e6d15a8 --- /dev/null +++ b/python/sglang/srt/configs/update_config.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +DEFAULT_MOE_PADDING_SIZE = 32 + + +if TYPE_CHECKING: + from sglang.srt.configs.load_config import LoadConfig + from sglang.srt.configs.model_config import ModelConfig + + +def may_get_weight_block_size(model_config, load_config): + from sglang.srt.model_loader.loader import _get_quantization_config + from sglang.srt.model_loader.utils import get_model_architecture + + model_class, _ = get_model_architecture(model_config) + packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) + + quant_config = _get_quantization_config( + model_config, load_config, packed_modules_mapping + ) + + if quant_config is not None and hasattr(quant_config, "weight_block_size"): + return getattr(quant_config, "weight_block_size") + return None + + +def get_moe_padding_size(weight_block_size): + if weight_block_size is not None: + # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + assert ( + len(weight_block_size) == 2 + ), "Only len(weight_block_size) == 2 is supported" + assert ( + weight_block_size[0] == weight_block_size[1] + ), "Only weight_block_size[0] == weight_block_size[1] is supported" + + return weight_block_size[0] + + return DEFAULT_MOE_PADDING_SIZE + + +def get_num_heads_padding_size(tp_size, weight_block_size): + pad_size = ( + tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size + ) + return pad_size + + +def update_intermediate_size(model_config, attr_name, intermediate_padding_size): + if hasattr(model_config.hf_config, attr_name): + attr_value = getattr(model_config.hf_config, attr_name) + if attr_value % intermediate_padding_size != 0: + from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size + + attr_value = pad_vocab_size(attr_value, intermediate_padding_size) + setattr(model_config.hf_config, attr_name, attr_value) + setattr(model_config.hf_text_config, attr_name, attr_value) + return model_config + + +def adjust_config_with_unaligned_cpu_tp( + model_config: ModelConfig, load_config: LoadConfig, tp_size: int +) -> ModelConfig: + # Support the case where the num_attention_heads is not divisible by the TP size. + weight_block_size = may_get_weight_block_size(model_config, load_config) + + model_config.hf_config.original_num_attention_heads = ( + model_config.num_attention_heads + ) + model_config.hf_text_config.original_num_attention_heads = ( + model_config.num_attention_heads + ) + + model_config.hf_config.original_total_num_kv_heads = ( + model_config.get_total_num_kv_heads() + ) + model_config.hf_text_config.original_total_num_kv_heads = ( + model_config.get_total_num_kv_heads() + ) + + if ( + model_config.num_attention_heads % tp_size != 0 + or model_config.get_total_num_kv_heads() % tp_size != 0 + ): + # Compute the head_dim using the model_config.num_attention_heads before padding + if not hasattr(model_config.hf_config, "head_dim"): + model_config.hf_config.head_dim = ( + model_config.hidden_size // model_config.num_attention_heads + ) + + query_heads_per_kv = ( + model_config.num_attention_heads // model_config.get_total_num_kv_heads() + ) + total_kv_heads = model_config.get_total_num_kv_heads() + from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size + + pad_size = get_num_heads_padding_size(tp_size, weight_block_size) + num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size) + + model_config.num_key_value_heads = num_key_value_heads + model_config.hf_config.num_key_value_heads = num_key_value_heads + model_config.hf_text_config.num_key_value_heads = num_key_value_heads + + num_attention_heads = num_key_value_heads * query_heads_per_kv + model_config.num_attention_heads = num_attention_heads + model_config.hf_config.num_attention_heads = num_attention_heads + model_config.hf_text_config.num_attention_heads = num_attention_heads + + intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size) + model_config = update_intermediate_size( + model_config, "moe_intermediate_size", intermediate_padding_size + ) + model_config = update_intermediate_size( + model_config, "intermediate_size", intermediate_padding_size + ) + + return model_config diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 1fc43b8b6..3fa012ce8 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -426,8 +426,26 @@ class ColumnParallelLinear(LinearBase): if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] start_idx = self.tp_rank * shard_size - if not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + if _is_cpu: + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, + ) + + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + 0, # param_data_start + start_idx, + output_dim, + shard_size, + not self.use_presharded_weights, + ) + else: + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size + ) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -644,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if not use_bitsandbytes_4bit and not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + if _is_cpu: + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, + ) + + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + 0, # param_data_start + start_idx, + output_dim, + shard_size, + not use_bitsandbytes_4bit and not self.use_presharded_weights, + ) + else: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit and not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size + ) + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -1112,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear): shard_id = self.tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - # bitsandbytes loads the weights of the specific portion - # no need to narrow here - if not use_bitsandbytes_4bit and not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + if _is_cpu: + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, + ) + + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + 0, # param_data_start + start_idx, + output_dim, + shard_size, + not use_bitsandbytes_4bit and not self.use_presharded_weights, + ) + else: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit and not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size + ) # Special case for for AQLM codebooks. elif is_metadata: @@ -1257,7 +1311,22 @@ class RowParallelLinear(LinearBase): ): shard_size = param_data.shape[input_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + if _is_cpu: + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, + ) + + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + 0, # param_data_start + start_idx, + input_dim, + shard_size, + ) + else: + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 9147136e3..997297be6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -573,11 +574,6 @@ class FusedMoE(torch.nn.Module): # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - if not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) - # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. # w3, up_proj: Load into second logical weight of w13. @@ -588,7 +584,24 @@ class FusedMoE(torch.nn.Module): start = shard_size else: start = 0 - expert_data = expert_data.narrow(shard_dim, start, shard_size) + + if _is_cpu: + expert_data, loaded_weight = narrow_padded_param_and_loaded_weight( + expert_data, + loaded_weight, + start, + shard_size * tp_rank, + shard_dim, + shard_size, + not self.use_presharded_weights, + ) + else: + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + + expert_data = expert_data.narrow(shard_dim, start, shard_size) expert_data.copy_(loaded_weight) def _load_w2( @@ -605,10 +618,21 @@ class FusedMoE(torch.nn.Module): # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] - if not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size + if _is_cpu: + expert_data, loaded_weight = narrow_padded_param_and_loaded_weight( + expert_data, + loaded_weight, + 0, # param_data_start + shard_size * tp_rank, + shard_dim, + shard_size, + not self.use_presharded_weights, ) + else: + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 978ec0ad0..d0ba43326 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -7,6 +7,8 @@ from typing import Callable, Optional, Union import torch from torch.nn import Parameter +from sglang.srt.utils import is_cpu + __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", @@ -21,6 +23,8 @@ __all__ = [ logger = logging.getLogger(__name__) +_is_cpu = is_cpu() + class BasevLLMParameter(Parameter): """ @@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter): ): if not use_presharded_weights: shard_size = self.data.shape[self.output_dim] - loaded_weight = loaded_weight.narrow( - self.output_dim, tp_rank * shard_size, shard_size + + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, ) + + if _is_cpu: + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + self.data, + loaded_weight, + 0, # param_data_start + tp_rank * shard_size, + self.output_dim, + shard_size, + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + else: + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) @@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter): param_data = self.data param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) - if not use_presharded_weights: - loaded_weight = loaded_weight.narrow( - self.output_dim, tp_rank * shard_size, shard_size + + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, + ) + + if _is_cpu: + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + 0, # param_data_start + tp_rank * shard_size, + self.output_dim, + shard_size, + not use_presharded_weights, ) + else: + if not use_presharded_weights: + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter): ): if not use_presharded_weights: shard_size = self.data.shape[self.input_dim] - loaded_weight = loaded_weight.narrow( - self.input_dim, tp_rank * shard_size, shard_size + + from sglang.srt.model_loader.weight_utils import ( + narrow_padded_param_and_loaded_weight, ) + if _is_cpu: + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + self.data, + loaded_weight, + 0, # param_data_start + tp_rank * shard_size, + self.input_dim, + shard_size, + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + return + else: + loaded_weight = loaded_weight.narrow( + self.input_dim, tp_rank * shard_size, shard_size + ) + if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 8e31a621c..d7056f5e0 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -246,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module): self.tp_size = 1 self.num_embeddings = num_embeddings - self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings + + # Support the case where the vocab size is not divisible by the TP size. + if ( + _is_cpu + and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0 + ): + padding_size *= self.tp_size + self.padding_size = padding_size + num_added_embeddings = num_embeddings - self.org_vocab_size self.use_presharded_weights = use_presharded_weights if use_presharded_weights: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8e910c0ee..faf030aab 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -149,6 +149,7 @@ from sglang.srt.utils import ( get_available_gpu_memory, get_bool_env_var, get_zmq_socket, + is_cpu, kill_itself_when_parent_died, point_to_point_pyobj, pyspy_dump_schedulers, @@ -167,6 +168,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) +_is_cpu = is_cpu() + @dataclass class GenerationBatchResult: @@ -2115,11 +2118,14 @@ class Scheduler( "kvcache": round( self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2 ), - "cuda_graph": round( - self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2 - ), "token_capacity": int(self.max_total_num_tokens), } + + if not _is_cpu: + ret["memory_usage"]["cuda_graph"] = round( + self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2 + ) + if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: ret["avg_spec_accept_length"] = ( self.cum_spec_accept_length / self.cum_spec_accept_count diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index de976c9af..4ff6bc18d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -29,6 +29,7 @@ import torch.distributed as dist from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.distributed import ( get_tp_group, @@ -165,7 +166,6 @@ class ModelRunner: token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, ): # Parse args - self.model_config = model_config self.mem_fraction_static = mem_fraction_static self.device = server_args.device self.gpu_id = gpu_id @@ -178,6 +178,7 @@ class ModelRunner: self.dp_size = server_args.dp_size self.pp_rank = pp_rank self.pp_size = pp_size + self.model_config = model_config self.dist_port = nccl_port self.server_args = server_args self.is_draft_worker = is_draft_worker @@ -604,6 +605,10 @@ class ModelRunner: download_dir=self.server_args.download_dir, model_loader_extra_config=self.server_args.model_loader_extra_config, ) + if self.device == "cpu": + self.model_config = adjust_config_with_unaligned_cpu_tp( + self.model_config, self.load_config, self.tp_size + ) if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index db5e3b3cb..b3cf18ec9 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -961,3 +961,57 @@ def kv_cache_scales_loader( tp_rank, ) return [] + + +def get_actual_shard_size(shard_size, weight_start, weight_end): + if weight_end < weight_start: + return 0 + + return min(shard_size, weight_end - weight_start) + + +def reset_param_data_if_needed(param_data, dim, start, length): + if length == 0: + return + + assert length > 0, f"Length should be positive, but got {length}" + + param_data.narrow(dim, start, length).zero_() + return + + +def narrow_padded_param_and_loaded_weight( + param_data, + loaded_weight, + param_data_start, + weight_start, + dim, + shard_size, + narrow_weight=True, +): + actual_shard_size = get_actual_shard_size( + shard_size, weight_start, loaded_weight.size(dim) + ) + + if narrow_weight: + if actual_shard_size > 0: + loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size) + else: + # No real data to load; create a dummy tensor filled with zeros + loaded_weight = torch.zeros_like( + param_data.narrow(dim, param_data_start, actual_shard_size) + ) + + # [Note] Reset padded weights to zero. + # If the actual shard size is less than the shard size, we need to reset + # the padded param_data to zero and then copy the loaded_weight into it. + reset_param_data_if_needed( + param_data, + dim, + param_data_start + actual_shard_size, + shard_size - actual_shard_size, + ) + + param_data = param_data.narrow(dim, param_data_start, actual_shard_size) + + return param_data, loaded_weight diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 5f7c0c006..73d1a0068 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -16,7 +16,9 @@ from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, is_cpu + +_is_cpu = is_cpu() class Llama4ForConditionalGeneration(nn.Module): @@ -107,13 +109,17 @@ class Llama4ForConditionalGeneration(nn.Module): # rotary embeds should be sliced if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight": - loaded_weight = permute( - loaded_weight, self.language_model.config.num_key_value_heads - ) + if _is_cpu: + dim = self.language_model.config.original_total_num_kv_heads + else: + dim = self.language_model.config.num_key_value_heads + loaded_weight = permute(loaded_weight, dim) elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight": - loaded_weight = permute( - loaded_weight, self.language_model.config.num_attention_heads - ) + if _is_cpu: + dim = self.language_model.config.original_num_attention_heads + else: + dim = self.language_model.config.num_attention_heads + loaded_weight = permute(loaded_weight, dim) return name, loaded_weight diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 10ac84ecc..714d53fe6 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -100,6 +100,7 @@ class Qwen2Attention(nn.Module): hidden_size: int, num_heads: int, num_kv_heads: int, + head_dim: Optional[int] = None, layer_id: int = 0, rope_theta: float = 1000000, rope_scaling: Optional[Dict[str, Any]] = None, @@ -123,7 +124,10 @@ class Qwen2Attention(nn.Module): # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads + if head_dim is not None: + self.head_dim = head_dim + else: + self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -191,10 +195,12 @@ class Qwen2DecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + head_dim = getattr(config, "head_dim", None) self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, + head_dim=head_dim, layer_id=layer_id, rope_theta=rope_theta, rope_scaling=rope_scaling, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5761aa8f4..e6cb7debc 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -13,6 +13,8 @@ # ============================================================================== """Common utilities.""" +from __future__ import annotations + import base64 import builtins import ctypes