Move sgl.Runtime under sglang/lang (#2990)
This commit is contained in:
@@ -23,7 +23,7 @@ import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.server import Runtime
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
@@ -278,7 +278,7 @@ class SRTRunner:
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
self.runtime = Runtime(
|
||||
self.engine = Engine(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
@@ -306,7 +306,7 @@ class SRTRunner:
|
||||
top_output_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = self.runtime.generate(
|
||||
response = self.engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
@@ -314,7 +314,6 @@ class SRTRunner:
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
response = json.loads(response)
|
||||
output_strs.append(response["text"])
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
@@ -343,8 +342,7 @@ class SRTRunner:
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
)
|
||||
else:
|
||||
response = self.runtime.encode(prompts)
|
||||
response = json.loads(response)
|
||||
response = self.engine.encode(prompts)
|
||||
if self.model_type == "embedding":
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
@@ -366,20 +364,18 @@ class SRTRunner:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
response = self.runtime.generate(
|
||||
response = self.engine.generate(
|
||||
prompts,
|
||||
lora_path=lora_paths if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
response = json.loads(response)
|
||||
output_strs = [r["text"] for r in response]
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
)
|
||||
else:
|
||||
response = self.runtime.encode(prompts)
|
||||
response = json.loads(response)
|
||||
response = self.engine.encode(prompts)
|
||||
if self.model_type == "embedding":
|
||||
logits = [x["embedding"] for x in response]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
@@ -391,8 +387,8 @@ class SRTRunner:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.runtime.shutdown()
|
||||
del self.runtime
|
||||
self.engine.shutdown()
|
||||
del self.engine
|
||||
|
||||
|
||||
def monkey_patch_gemma2_sdpa():
|
||||
|
||||
Reference in New Issue
Block a user