Minor fix in compiler & format (#545)

This commit is contained in:
sglang
2024-06-29 23:42:14 -07:00
committed by GitHub
parent 9ce89bc14b
commit 11616fc6bd
12 changed files with 28 additions and 33 deletions

View File

@@ -13,7 +13,6 @@ except ImportError as e:
class LiteLLM(BaseBackend):
def __init__(
self,
model_name,

View File

@@ -4,7 +4,7 @@ from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
-from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
+from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:
-pin_program(self.function, backend)
+cache_program(self.function, backend)
# Run all programs
if num_threads == "auto":

View File

@@ -6,7 +6,6 @@ import multiprocessing as mp
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2

View File

@@ -498,9 +498,10 @@ class Batch:
req.output_ids = cur_output_ids
continue
-jump_forward_str, next_state = (
-    req.jump_forward_map.jump_forward_symbol(cur_state)
-)
+(
+    jump_forward_str,
+    next_state,
+) = req.jump_forward_map.jump_forward_symbol(cur_state)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt

View File

@@ -283,13 +283,14 @@ class ModelTpServer:
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
-req.origin_input_ids, req.image_offset = (
-    self.model_runner.model.pad_input_ids(
-        req.origin_input_ids_unpadded,
-        req.pad_value,
-        req.pixel_values.shape,
-        req.image_size,
-    )
+(
+    req.origin_input_ids,
+    req.image_offset,
+) = self.model_runner.model.pad_input_ids(
+    req.origin_input_ids_unpadded,
+    req.pad_value,
+    req.pixel_values.shape,
+    req.image_size,
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob

View File

@@ -35,7 +35,6 @@ class GenerateReqInput:
stream: bool = False
def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):

View File

@@ -334,15 +334,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
-ret["meta_info"]["prefill_top_logprobs"] = (
-    self.detokenize_top_logprobs_tokens(
-        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-    )
+ret["meta_info"][
+    "prefill_top_logprobs"
+] = self.detokenize_top_logprobs_tokens(
+    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
-ret["meta_info"]["decode_top_logprobs"] = (
-    self.detokenize_top_logprobs_tokens(
-        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-    )
+ret["meta_info"][
+    "decode_top_logprobs"
+] = self.detokenize_top_logprobs_tokens(
+    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret

View File

@@ -36,7 +36,6 @@ LoraConfig = None
class GLMAttention(nn.Module):
def __init__(
self,
config,
@@ -294,7 +293,6 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module):
def __init__(
self,
config,

View File

@@ -521,7 +521,6 @@ class Grok1DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = (
self.post_attn_norm(
self.self_attn(

View File

@@ -160,9 +160,9 @@ class LlamaDecoderLayer(nn.Module):
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
-rope_scaling["original_max_position_embeddings"] = (
-    config.original_max_position_embeddings
-)
+rope_scaling[
+    "original_max_position_embeddings"
+] = config.original_max_position_embeddings
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,