[Feat] Add window attention for gemma-2 (#1056)
This commit is contained in:
1
python/sglang/test/long_prompt
Normal file
1
python/sglang/test/long_prompt
Normal file
File diff suppressed because one or more lines are too long
@@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
"AI is a field of computer science focused on",
|
||||
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
|
||||
]
|
||||
|
||||
dirpath = os.path.dirname(__file__)
|
||||
with open(os.path.join(dirpath, "long_prompt"), "r") as f:
|
||||
long_prompt = f.read()
|
||||
DEFAULT_PROMPTS.append(long_prompt)
|
||||
|
||||
NUM_TOP_LOGPROBS = 5
|
||||
|
||||
|
||||
@@ -125,16 +132,14 @@ class HFRunner:
|
||||
)
|
||||
|
||||
logits = self.model.forward(input_ids).logits[0]
|
||||
logprobs = F.log_softmax(
|
||||
logits, dim=-1, dtype=torch.float32
|
||||
).tolist()
|
||||
# index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
|
||||
# print("index", index_of_max)
|
||||
logprobs = [
|
||||
sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
|
||||
for token_logprobs in logprobs
|
||||
]
|
||||
prefill_logprobs.append(logprobs)
|
||||
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||
logprobs, top_indices = torch.topk(
|
||||
logprobs, k=NUM_TOP_LOGPROBS, dim=-1
|
||||
)
|
||||
# print("index", top_indices)
|
||||
prefill_logprobs.append(logprobs.tolist())
|
||||
del logits
|
||||
del logprobs
|
||||
|
||||
out_queue.put(
|
||||
ModelOutput(
|
||||
@@ -186,6 +191,7 @@ class SRTRunner:
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=0.7,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user