[Feat] Add window attention for gemma-2 (#1056)

This commit is contained in:
Ying Sheng
2024-08-13 17:01:26 -07:00
committed by GitHub
parent ad3e4f1619
commit 0909bb0d2f
11 changed files with 320 additions and 127 deletions

File diff suppressed because one or more lines are too long

View File

@@ -15,6 +15,7 @@ limitations under the License.
import json
import multiprocessing
import os
from dataclasses import dataclass
from typing import List, Union
@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
"The capital of the United Kindom is",
"Today is a sunny day and I like",
"AI is a field of computer science focused on",
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
]
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt"), "r") as f:
long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)
NUM_TOP_LOGPROBS = 5
@@ -125,16 +132,14 @@ class HFRunner:
)
logits = self.model.forward(input_ids).logits[0]
logprobs = F.log_softmax(
logits, dim=-1, dtype=torch.float32
).tolist()
# index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
# print("index", index_of_max)
logprobs = [
sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
for token_logprobs in logprobs
]
prefill_logprobs.append(logprobs)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
logprobs, top_indices = torch.topk(
logprobs, k=NUM_TOP_LOGPROBS, dim=-1
)
# print("index", top_indices)
prefill_logprobs.append(logprobs.tolist())
del logits
del logprobs
out_queue.put(
ModelOutput(
@@ -186,6 +191,7 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.7,
)
def forward(