fix black in pre-commit (#1940)
@@ -39,7 +39,7 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
-        is_embedding: Optional[bool] = None
+        is_embedding: Optional[bool] = None,
     ) -> None:
         # Parse args
         self.model_override_args = json.loads(model_override_args)
@@ -52,7 +52,9 @@ class ModelConfig:
         self.hf_text_config = get_hf_text_config(self.hf_config)

         # Check model type
-        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)

@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
         """
         raise NotImplementedError

-def method_has_implemented_embedding(
-        method_class: Type[QuantizeMethodBase]) -> bool:
+
+def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
     """
     Not all quant methods have embedding implemented, so we need to check that
     it exists for our given method. We check this by making sure the function
     has been changed from the base implementation.
     """
-    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
-                                            None)
+    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
     class_embedding = inspect.getattr_static(method_class, "embedding", None)

-    return (class_embedding is not None
-            and class_embedding is not base_embedding)
+    return class_embedding is not None and class_embedding is not base_embedding
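For context on the hunk above: method_has_implemented_embedding relies on inspect.getattr_static to decide whether a subclass actually overrides embedding, by comparing the statically retrieved attribute against the base class's. A minimal standalone sketch of the same idiom, with hypothetical Base/WithEmbedding classes that are not from this repo:

    import inspect

    class Base:
        def embedding(self):
            raise NotImplementedError

    class WithEmbedding(Base):
        def embedding(self):  # overrides the base implementation
            return "ok"

    class WithoutEmbedding(Base):
        pass  # inherits Base.embedding unchanged

    def has_override(cls) -> bool:
        # getattr_static fetches the raw function object without
        # triggering descriptors, so identity comparison is reliable.
        base = inspect.getattr_static(Base, "embedding", None)
        sub = inspect.getattr_static(cls, "embedding", None)
        return sub is not None and sub is not base

    assert has_override(WithEmbedding) is True
    assert has_override(WithoutEmbedding) is False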
@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""

-    def create_weights(self, layer: torch.nn.Module,
-                       input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
-                       output_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
         """Create weights for embedding layer."""
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)

-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)

-    def embedding(self, layer: torch.nn.Module,
-                  input_: torch.Tensor) -> torch.Tensor:
+    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
         return F.embedding(input_, layer.weight)


-def pad_vocab_size(vocab_size: int,
-                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to


 def vocab_range_from_per_partition_vocab_size(
-        per_partition_vocab_size: int,
-        rank: int,
-        offset: int = 0) -> Sequence[int]:
+    per_partition_vocab_size: int, rank: int, offset: int = 0
+) -> Sequence[int]:
     index_f = rank * per_partition_vocab_size
     index_l = index_f + per_partition_vocab_size
     return index_f + offset, index_l + offset


-def vocab_range_from_global_vocab_size(global_vocab_size: int,
-                                       rank: int,
-                                       world_size: int,
-                                       offset: int = 0) -> Sequence[int]:
+def vocab_range_from_global_vocab_size(
+    global_vocab_size: int, rank: int, world_size: int, offset: int = 0
+) -> Sequence[int]:
     per_partition_vocab_size = divide(global_vocab_size, world_size)
-    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
-                                                     rank,
-                                                     offset=offset)
+    return vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size, rank, offset=offset
+    )


 @dataclass
 class VocabParallelEmbeddingShardIndices:
     """Indices for a shard of a vocab parallel embedding."""

     padded_org_vocab_start_index: int
     padded_org_vocab_end_index: int
     padded_added_vocab_start_index: int
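The padding helpers reformatted above are plain ceiling-division arithmetic: pad_vocab_size rounds the vocabulary up to the next multiple of pad_to, and the range helpers then slice the padded vocabulary evenly across tensor-parallel ranks. A self-contained sketch of that arithmetic, with illustrative numbers only:

    def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
        # Smallest multiple of pad_to that is >= vocab_size.
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    def vocab_range(global_vocab_size: int, rank: int, world_size: int):
        # Assumes the padded size divides evenly across ranks.
        per_partition = global_vocab_size // world_size
        start = rank * per_partition
        return start, start + per_partition

    assert pad_vocab_size(32000) == 32000  # already a multiple of 64
    assert pad_vocab_size(32003) == 32064  # rounded up to the next multiple
    # A padded 32064-entry vocab on 4 ranks: 8016 entries per rank.
    assert vocab_range(32064, 0, 4) == (0, 8016)
    assert vocab_range(32064, 3, 4) == (24048, 32064)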
@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:

     @property
     def num_org_elements_padded(self) -> int:
-        return (self.padded_org_vocab_end_index -
-                self.padded_org_vocab_start_index)
+        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index

     @property
     def num_added_elements_padded(self) -> int:
-        return (self.padded_added_vocab_end_index -
-                self.padded_added_vocab_start_index)
+        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index

     @property
     def num_org_vocab_padding(self) -> int:
@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:

     def __post_init__(self):
         # sanity checks
-        assert (self.padded_org_vocab_start_index <=
-                self.padded_org_vocab_end_index)
-        assert (self.padded_added_vocab_start_index <=
-                self.padded_added_vocab_end_index)
+        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index

         assert self.org_vocab_start_index <= self.org_vocab_end_index
         assert self.added_vocab_start_index <= self.added_vocab_end_index

         assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
-        assert (self.added_vocab_start_index <=
-                self.padded_added_vocab_start_index)
+        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
         assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
         assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:

 @torch.jit.script
 def get_masked_input_and_mask(
-        input_: torch.Tensor, org_vocab_start_index: int,
-        org_vocab_end_index: int, num_org_vocab_padding: int,
-        added_vocab_start_index: int,
-        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    input_: torch.Tensor,
+    org_vocab_start_index: int,
+    org_vocab_end_index: int,
+    num_org_vocab_padding: int,
+    added_vocab_start_index: int,
+    added_vocab_end_index: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     # torch.jit.script will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
-    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
-                                                          org_vocab_end_index)
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
-        input_ < added_vocab_end_index)
-    added_offset = added_vocab_start_index - (
-        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
-    valid_offset = (org_vocab_start_index *
-                    org_vocab_mask) + (added_offset * added_vocab_mask)
+        input_ < added_vocab_end_index
+    )
+    added_offset = (
+        added_vocab_start_index
+        - (org_vocab_end_index - org_vocab_start_index)
+        - num_org_vocab_padding
+    )
+    valid_offset = (org_vocab_start_index * org_vocab_mask) + (
+        added_offset * added_vocab_mask
+    )
     vocab_mask = org_vocab_mask | added_vocab_mask
     input_ = vocab_mask * (input_ - valid_offset)
     return input_, ~vocab_mask
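The function reformatted above maps global token ids into this shard's local index space and returns a mask of ids owned by other ranks; torch.jit.script then fuses the pointwise ops into one kernel. A rough illustration of the masking arithmetic with toy values (a rank owning original ids [4, 8), no padding and no added vocab, which is a simplification of the real shard layout):

    import torch

    input_ = torch.tensor([0, 4, 5, 9])
    start, end = 4, 8  # this rank owns original vocab ids [4, 8)

    vocab_mask = (input_ >= start) & (input_ < end)
    valid_offset = start * vocab_mask  # subtract the offset only for our ids
    local_ids = vocab_mask * (input_ - valid_offset)

    print(vocab_mask)  # tensor([False,  True,  True, False])
    print(local_ids)   # tensor([0, 0, 1, 0]); foreign ids collapse to row 0
    # After the lookup, rows where ~vocab_mask holds are zeroed, so an
    # all-reduce across ranks reassembles the full embedding output.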
@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
         prefix: full name of the layer in the state dict
     """  # noqa: E501

-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "",
-                 enable_tp: bool = True):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_tp: bool = True,
+    ):
         super().__init__()

         self.enable_tp = enable_tp
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
-                                                    self.padding_size)
+        self.org_vocab_size_padded = pad_vocab_size(
+            self.org_vocab_size, self.padding_size
+        )
         self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings,
-            self.padding_size)
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
         assert self.org_vocab_size_padded <= self.num_embeddings_padded

-        self.shard_indices = self._get_indices(self.num_embeddings_padded,
-                                               self.org_vocab_size_padded,
-                                               self.num_embeddings,
-                                               self.org_vocab_size, tp_rank,
-                                               self.tp_size)
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            tp_rank,
+            self.tp_size,
+        )
         self.embedding_dim = embedding_dim

         linear_method = None
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
         linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
+            type(linear_method)
+        )
         if is_embedding_layer and not linear_method_implements_embedding:
             raise NotImplementedError(
                 f"The class {type(linear_method).__name__} must implement "
-                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+                "the 'embedding' method, see UnquantizedEmbeddingMethod."
+            )

         self.linear_method: QuantizeMethodBase = linear_method
@@ -260,53 +278,68 @@ class VocabParallelEmbedding(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocaburaly dimension.
         self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
-        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
-                                                   self.tp_size)
-        assert (self.shard_indices.num_elements_padded ==
-                self.num_embeddings_per_partition)
+        self.num_embeddings_per_partition = divide(
+            self.num_embeddings_padded, self.tp_size
+        )
+        assert (
+            self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
+        )
         self.num_org_embeddings_per_partition = (
-            self.shard_indices.org_vocab_end_index -
-            self.shard_indices.org_vocab_start_index)
+            self.shard_indices.org_vocab_end_index
+            - self.shard_indices.org_vocab_start_index
+        )
         self.num_added_embeddings_per_partition = (
-            self.shard_indices.added_vocab_end_index -
-            self.shard_indices.added_vocab_start_index)
+            self.shard_indices.added_vocab_end_index
+            - self.shard_indices.added_vocab_start_index
+        )

-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.linear_method.create_weights(
+            self,
+            self.embedding_dim,
+            [self.num_embeddings_per_partition],
+            self.embedding_dim,
+            self.num_embeddings_padded,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )

     @classmethod
-    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
-                     vocab_size: int, org_vocab_size: int, tp_rank: int,
-                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
+    def _get_indices(
+        cls,
+        vocab_size_padded: int,
+        org_vocab_size_padded: int,
+        vocab_size: int,
+        org_vocab_size: int,
+        tp_rank: int,
+        tp_size: int,
+    ) -> VocabParallelEmbeddingShardIndices:
         """Get start and end indices for vocab parallel embedding, following the
         layout outlined in the class docstring, based on the given tp_rank and
         tp_size."""
         num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
         padded_org_vocab_start_index, padded_org_vocab_end_index = (
-            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
-                                               tp_size))
+            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
+        )
         padded_added_vocab_start_index, padded_added_vocab_end_index = (
-            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
-                                               tp_rank,
-                                               tp_size,
-                                               offset=org_vocab_size))
+            vocab_range_from_global_vocab_size(
+                num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+            )
+        )
         # remove padding
-        org_vocab_start_index = min(padded_org_vocab_start_index,
-                                    org_vocab_size)
+        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
         org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
-        added_vocab_start_index = min(padded_added_vocab_start_index,
-                                      vocab_size)
+        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
         added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
         return VocabParallelEmbeddingShardIndices(
-            padded_org_vocab_start_index, padded_org_vocab_end_index,
-            padded_added_vocab_start_index, padded_added_vocab_end_index,
-            org_vocab_start_index, org_vocab_end_index,
-            added_vocab_start_index, added_vocab_end_index)
+            padded_org_vocab_start_index,
+            padded_org_vocab_end_index,
+            padded_added_vocab_start_index,
+            padded_added_vocab_end_index,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            added_vocab_start_index,
+            added_vocab_end_index,
+        )

     def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
         """Get a mapping that can be used to reindex the gathered
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
         added_embeddings: List[int] = []
         padding: List[int] = []
         for tp_rank in range(self.tp_size):
-            shard_indices = self._get_indices(self.num_embeddings_padded,
-                                              self.org_vocab_size_padded,
-                                              self.num_embeddings,
-                                              self.org_vocab_size, tp_rank,
-                                              self.tp_size)
+            shard_indices = self._get_indices(
+                self.num_embeddings_padded,
+                self.org_vocab_size_padded,
+                self.num_embeddings,
+                self.org_vocab_size,
+                tp_rank,
+                self.tp_size,
+            )
             range_start = self.num_embeddings_per_partition * tp_rank
             range_end = self.num_embeddings_per_partition * (tp_rank + 1)
             base_embeddings.extend(
-                range(range_start,
-                      range_start + shard_indices.num_org_elements))
+                range(range_start, range_start + shard_indices.num_org_elements)
+            )
             padding.extend(
-                range(range_start + shard_indices.num_org_elements,
-                      range_start + shard_indices.num_org_elements_padded))
+                range(
+                    range_start + shard_indices.num_org_elements,
+                    range_start + shard_indices.num_org_elements_padded,
+                )
+            )
             added_embeddings.extend(
                 range(
                     range_start + shard_indices.num_org_elements_padded,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements))
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                )
+            )
             padding.extend(
                 range(
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded))
-            assert (range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded == range_end)
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements_padded,
+                )
+            )
+            assert (
+                range_start
+                + shard_indices.num_org_elements_padded
+                + shard_indices.num_added_elements_padded
+                == range_end
+            )
         ret = base_embeddings + added_embeddings + padding
         assert len(ret) == self.num_embeddings_padded
         return ret
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         # If param packed on the same dim we are sharding on, then
         # need to adjust offsets of loaded weight by pack_factor.
         if packed_dim is not None and packed_dim == output_dim:
-            packed_factor = param.packed_factor if isinstance(
-                param, BasevLLMParameter) else param.pack_factor
-            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                       param.packed_factor)
+            packed_factor = (
+                param.packed_factor
+                if isinstance(param, BasevLLMParameter)
+                else param.pack_factor
+            )
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size // param.packed_factor
+            )
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):

         # Copy the data.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
-        param[loaded_weight.shape[0]:].data.fill_(0)
+        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
+        param[loaded_weight.shape[0] :].data.fill_(0)

     def forward(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
-                input_, self.shard_indices.org_vocab_start_index,
+                input_,
+                self.shard_indices.org_vocab_start_index,
                 self.shard_indices.org_vocab_end_index,
                 self.shard_indices.num_org_vocab_padding,
                 self.shard_indices.added_vocab_start_index,
-                self.shard_indices.added_vocab_end_index)
+                self.shard_indices.added_vocab_end_index,
+            )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.linear_method.embedding(self, masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"
         s += f", org_vocab_size={self.org_vocab_size}"
-        s += f', num_embeddings_padded={self.num_embeddings_padded}'
+        s += f", num_embeddings_padded={self.num_embeddings_padded}"
         if self.enable_tp:
-            s += f', tp_size={self.tp_size}'
+            s += f", tp_size={self.tp_size}"
         return s


@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: padding size for the vocabulary.
     """

-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config,
-                         prefix)
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
-                torch.empty(self.num_embeddings_per_partition,
-                            dtype=params_dtype))
-            set_weight_attrs(self.bias, {
-                "output_dim": 0,
-                "weight_loader": self.weight_loader,
-            })
+                torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
+            )
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
         else:
             self.register_parameter("bias", None)
@@ -86,8 +86,10 @@ class GenerateReqInput:
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
-                "The parallel_sample_num should be the same for all samples in sample params.")
+            assert all(
+                self.parallel_sample_num == sampling_params.get("n", 1)
+                for sampling_params in self.sampling_params
+            ), "The parallel_sample_num should be the same for all samples in sample params."

         if self.parallel_sample_num > 1 and self.is_single:
             self.is_single = False
@@ -911,8 +911,7 @@ class ScheduleBatch:
         keep_indices = [
             i
             for i in range(len(self.reqs))
-            if not self.reqs[i].finished()
-            and self.reqs[i] is not being_chunked_req
+            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
         ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1043,6 +1042,7 @@ class ScheduleBatch:
         for req in self.reqs:
             req.started_time = time.time()

+
 @dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
@@ -224,8 +224,8 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.last_stats_tic = time.time() # time of last stats for every iter
-        self.last_log_tic = time.time() # time of last log for print decode log
+        self.last_stats_tic = time.time()  # time of last stats for every iter
+        self.last_log_tic = time.time()  # time of last log for print decode log
         self.stream_interval = server_args.stream_interval

         # Init chunked prefill
@@ -566,9 +566,7 @@ class Scheduler:
             and not self.last_batch.is_empty()
         ):
             if self.being_chunked_req:
-                self.last_batch.filter_batch(
-                    being_chunked_req=self.being_chunked_req
-                )
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -628,9 +626,7 @@ class Scheduler:
         has_inflight = self.being_chunked_req is not None
         if has_inflight:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.add_inflight_req(
-                self.being_chunked_req
-            )
+            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

         if self.lora_paths:
             lora_set = (
@@ -813,7 +809,8 @@ class Scheduler:
         embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
         ret = embeddings, model_worker_batch.bid
         return ret
-    def get_stats(self,batch: ScheduleBatch):
+
+    def get_stats(self, batch: ScheduleBatch):
         # TODO: get stats for chunked prefill

         now = time.time()
@@ -829,8 +826,8 @@ class Scheduler:
         # set stats from prefill
         if self.stats is not None:
             # new_seq=self.stats.new_seq
-            cache_hit_rate=self.stats.cache_hit_rate
-            token_usage=self.stats.token_usage
+            cache_hit_rate = self.stats.cache_hit_rate
+            token_usage = self.stats.token_usage
         # Iteration stats
         num_prompt_tokens_iter = 0
         num_generation_tokens_iter = 0
@@ -851,15 +848,19 @@ class Scheduler:
         # _, next_token_ids, _ = result
         if batch is not None:
             num_generation_tokens_iter = len(batch.output_ids)
-            gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2)
+            gen_throughput = round(
+                num_generation_tokens_iter / (now - self.last_stats_tic), 2
+            )

             for i, req in enumerate(batch.reqs):
                 # NOTE: Batch forward mode is extend befor start decode,
                 if batch.forward_mode.is_extend():
-                    num_prompt_tokens_iter=len(batch.input_ids)+sum(batch.prefix_lens)
+                    num_prompt_tokens_iter = len(batch.input_ids) + sum(
+                        batch.prefix_lens
+                    )
                     time_to_first_tokens_iter.append(now - req.started_time)
                 else:
-                    time_per_output_tokens_iter.append(now-self.last_stats_tic)
+                    time_per_output_tokens_iter.append(now - self.last_stats_tic)

                 if req.finished():
                     time_e2e_requests.append(now - req.created_time)
@@ -867,9 +868,10 @@ class Scheduler:
                     num_prompt_tokens_requests.append(len(req.origin_input_ids))
                     num_generation_tokens_requests.append(len(req.output_ids))
                     finished_reason_requests.append(
-                        req.finished_reason.to_json()
-                        if req.finished_reason is not None
-                        else None)
+                        req.finished_reason.to_json()
+                        if req.finished_reason is not None
+                        else None
+                    )

         return Stats(
             new_seq=new_seq,
@@ -893,7 +895,7 @@ class Scheduler:
             max_running_requests=self.max_running_requests,
         )

-    def log_stats(self,stats:Stats):
+    def log_stats(self, stats: Stats):
         self.metrics_collector.log_stats(stats)

     def process_batch_result(self, batch: ScheduleBatch, result):
@@ -1003,9 +1005,7 @@ class Scheduler:
             if req.is_retracted:
                 continue

-            if self.server_args.enable_overlap_schedule and (
-                req.finished()
-            ):
+            if self.server_args.enable_overlap_schedule and (req.finished()):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

@@ -1031,7 +1031,10 @@ class Scheduler:
         self.token_to_kv_pool.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
+        if (
+            self.tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
             self.print_decode_stats()

     def add_logprob_return_values(
@@ -215,7 +215,7 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                obj.lora_path
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -290,7 +290,9 @@ class TokenizerManager:

         # Tokenize all requests
         objs = [obj[i] for i in range(batch_size)]
-        tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+        tokenized_objs = await asyncio.gather(
+            *(self._tokenize_one_request(obj) for obj in objs)
+        )

         # Cache the common prefix for parallel sampling
         for i in range(batch_size):
@@ -322,7 +324,9 @@ class TokenizerManager:
         rid_to_index = {rid: i for i, rid in enumerate(rids)}
         task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
         while task_map:
-            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+            done, _ = await asyncio.wait(
+                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+            )

             for task in done:
                 gen = task_map.pop(task)
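The hunk above reflows a standard pattern for multiplexing several async generators: keep a task-to-generator map, wait with return_when=asyncio.FIRST_COMPLETED, and reschedule __anext__ on whichever generator produced a value. A minimal standalone sketch of that pattern (the toy producers are illustrative only):

    import asyncio

    async def gen(name: str, n: int):
        for i in range(n):
            await asyncio.sleep(0.01)
            yield f"{name}:{i}"

    async def multiplex(generators):
        # Map each in-flight __anext__ task back to its generator.
        task_map = {asyncio.create_task(g.__anext__()): g for g in generators}
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                g = task_map.pop(task)
                try:
                    print(task.result())
                except StopAsyncIteration:
                    continue  # this generator is exhausted
                # Reschedule the generator for its next item.
                task_map[asyncio.create_task(g.__anext__())] = g

    asyncio.run(multiplex([gen("a", 2), gen("b", 3)]))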
@@ -367,7 +371,7 @@ class TokenizerManager:
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
-        else: # self.server_args.dp_size > 1
+        else:  # self.server_args.dp_size > 1
             self.mem_pool_size_tmp = []
             res = await self.mem_pool_size
             ret = [r.size for r in res]
@@ -399,7 +403,7 @@ class TokenizerManager:
             self.server_args.load_format = obj.load_format
             self.model_path = obj.model_path
             return result.success, result.message
-        else: # self.server_args.dp_size > 1
+        else:  # self.server_args.dp_size > 1
             self.model_update_tmp = []
             result = await self.model_update_result

@@ -470,7 +474,7 @@ class TokenizerManager:
         if isinstance(recv_obj, UpdateWeightReqOutput):
             if self.server_args.dp_size == 1:
                 self.model_update_result.set_result(recv_obj)
-            else: # self.server_args.dp_size > 1
+            else:  # self.server_args.dp_size > 1
                 self.model_update_tmp.append(recv_obj)
                 # set future if the all results are recevied
                 if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +483,7 @@ class TokenizerManager:
         elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
             if self.server_args.dp_size == 1:
                 self.mem_pool_size.set_result(recv_obj)
-            else: # self.sever_args.dp_size > 1
+            else:  # self.sever_args.dp_size > 1
                 self.mem_pool_size_tmp.append(recv_obj)
                 # set future if the all results are received
                 if len(self.mem_pool_size_tmp) == self.server_args.dp_size:

@@ -130,27 +130,65 @@ class Metrics:
         self.counter_prompt_tokens = Counter(
             name="sglang:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.counter_generation_tokens = Counter(
             name="sglang:generation_tokens_total",
             documentation="Number of generation tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
-                0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0
-            ])
+                0.001,
+                0.005,
+                0.01,
+                0.02,
+                0.04,
+                0.06,
+                0.08,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+                5.0,
+                7.5,
+                10.0,
+                15.0,
+                20.0,
+                25.0,
+                30.0,
+            ],
+        )
         self.histogram_time_per_output_token = Histogram(
             name="sglang:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
-                1.0, 2.5
-            ])
+                0.005,
+                0.01,
+                0.015,
+                0.02,
+                0.025,
+                0.03,
+                0.04,
+                0.05,
+                0.075,
+                0.1,
+                0.15,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+            ],
+        )

         # Request Stats
         # Metadata
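For reference, the bucket lists being reflowed above follow the prometheus_client Histogram API, where buckets are the inclusive upper bounds of the latency bins (a +Inf bucket is added implicitly). A minimal sketch of defining and observing such a histogram; the metric name and label value here are made up for illustration:

    from prometheus_client import Histogram

    ttft = Histogram(
        name="example:time_to_first_token_seconds",
        documentation="Histogram of time to first token in seconds.",
        labelnames=["model_name"],
        buckets=[0.001, 0.01, 0.1, 1.0, 10.0],
    )

    # Histograms are cumulative: an observation increments every bucket
    # whose upper bound is >= the observed value.
    ttft.labels(model_name="demo").observe(0.042)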
@@ -245,14 +283,19 @@ class PrometheusMetricsCollector(MetricsCollector):
             stats.num_generation_tokens_requests,
         )

-        self._log_counter(self.metrics.counter_prompt_tokens,
-                          stats.num_prompt_tokens_iter)
-        self._log_counter(self.metrics.counter_generation_tokens,
-                          stats.num_generation_tokens_iter)
-        self._log_histogram(self.metrics.histogram_time_to_first_token,
-                            stats.time_to_first_tokens_iter)
-        self._log_histogram(self.metrics.histogram_time_per_output_token,
-                            stats.time_per_output_tokens_iter)
+        self._log_counter(
+            self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
+        )
+        self._log_counter(
+            self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_per_output_token,
+            stats.time_per_output_tokens_iter,
+        )

         # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
         self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-#from sglang.srt.layers.activation import get_act_fn
+# from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
@@ -76,11 +75,13 @@ class GPT2Attention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling=self.scale,
-                                   num_kv_heads=total_num_heads,
-                                   layer_id=layer_id)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            scaling=self.scale,
+            num_kv_heads=total_num_heads,
+            layer_id=layer_id,
+        )

     def forward(
         self,
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(
+            config.activation_function, quant_config, intermediate_size
+        )

-    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.c_proj(hidden_states)
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config = None,
-
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
-                     hidden_size)
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(layer_id,
-                                  config,
-                                  cache_config,
-                                  quant_config,
-                                  prefix=f"{prefix}.attn")
+        self.attn = GPT2Attention(
+            layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+        )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim,
-                           config,
-                           quant_config,
-                           prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
         return hidden_states


 class GPT2Model(nn.Module):
-
     def __init__(
         self,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config,
-                                     cache_config,
-                                     quant_config,
-                                     prefix="transformer")
+        self.transformer = GPT2Model(
+            config, cache_config, quant_config, prefix="transformer"
+        )
         self.lm_head = self.transformer.wte

         self.logits_processor = LogitsProcessor(config)
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

-
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
             if not name.endswith(".weight"):
                 continue
             loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)


 EntryClass = GPT2LMHeadModel
@@ -419,6 +419,7 @@ def launch_engine(
     for i in range(len(scheduler_pipe_readers)):
         scheduler_pipe_readers[i].recv()

+
 def add_prometheus_middleware(app: FastAPI):
     # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
     from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
@@ -490,6 +491,7 @@ def launch_server(
     finally:
         t.join()

+
 def _set_prometheus_env():
     # Set prometheus multiprocess directory
     # sglang uses prometheus multiprocess mode
@@ -506,6 +508,7 @@ def _set_prometheus_env():
     os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
     logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")

+
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -763,8 +766,8 @@ class Engine:
         # runtime server default log level is log
         # offline engine works in scripts, so we set it to error

-        if 'log_level' not in kwargs:
-            kwargs['log_level'] = 'error'
+        if "log_level" not in kwargs:
+            kwargs["log_level"] = "error"

         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)
@@ -448,7 +448,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch"
+            help="The log interval of decode batch",
         )

         # Data parallelism
@@ -742,7 +742,13 @@ def run_mmlu_test(
     finally:
         pass

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )


 def run_mulit_request_test(
@@ -775,4 +781,10 @@ def run_mulit_request_test(
     with ThreadPoolExecutor(2) as executor:
         list(executor.map(run_one, list(range(4))))

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )
@@ -349,6 +349,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:

 def terminate_process(process):
     from sglang.srt.utils import kill_child_process

     kill_child_process(process.pid, include_self=True)

+