Move sgl.Runtime under sglang/lang (#2990)

This commit is contained in:
Lianmin Zheng
2025-01-19 17:10:29 -08:00
committed by GitHub
parent e403d23757
commit 61f42b5732
17 changed files with 267 additions and 329 deletions

View File

@@ -23,7 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
DEFAULT_PROMPTS = [
@@ -278,7 +278,7 @@ class SRTRunner:
):
self.model_type = model_type
self.is_generation = model_type == "generation"
self.runtime = Runtime(
self.engine = Engine(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
@@ -306,7 +306,7 @@ class SRTRunner:
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = self.runtime.generate(
response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
@@ -314,7 +314,6 @@ class SRTRunner:
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
@@ -343,8 +342,7 @@ class SRTRunner:
top_output_logprobs=top_output_logprobs,
)
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
@@ -366,20 +364,18 @@ class SRTRunner:
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = self.runtime.generate(
response = self.engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
response = json.loads(response)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
@@ -391,8 +387,8 @@ class SRTRunner:
return self
def __exit__(self, exc_type, exc_value, traceback):
self.runtime.shutdown()
del self.runtime
self.engine.shutdown()
del self.engine
def monkey_patch_gemma2_sdpa():