Update README.md

This commit is contained in:
wenge-research
2023-07-26 06:44:17 +00:00
committed by huggingface-web
parent 18a4ed3828
commit f1a9e8d91e

View File

@@ -30,5 +30,53 @@ tags:
## 运行方式
Comming Soon~
```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList
pretrained_model_name_or_path = "wenge-research/yayi-7b-llama2"
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=False)
# Define the stopping criteria
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords_ids:list):
self.keywords = keywords_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if input_ids[0][-1] in self.keywords:
return True
return False
stop_words = ["<|End|>", "<|YaYi|>", "<|Human|>", "</s>"]
stop_ids = [tokenizer.encode(w)[-1] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
# inference
prompt = "你是谁?"
formatted_prompt = f"""<|System|>:
You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<|Human|>:
{prompt}
<|YaYi|>:
"""
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
eos_token_id = tokenizer("<|End|>").input_ids[0]
generation_config = GenerationConfig(
eos_token_id=eos_token_id,
pad_token_id=eos_token_id,
do_sample=True,
max_new_tokens=256,
temperature=0.3,
repetition_penalty=1.1,
no_repeat_ngram_size=0
)
response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
response = [response[0][len(inputs.input_ids[0]):]]
response_str = tokenizer.batch_decode(response, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
print(response_str)
```