Shurale7B-v1: Narrative based chit-chat model
Developed by @BobaZooba | CV | LinkedIn | bobazooba@gmail.com
About
Model based on Mistral-7B-v0.1
GitHub Repo | Detailed step-by-step guide on how to train this model
What is Shurale?
- Shurale is an open-domain dialogue model for chit-chat conversations
- The model has the capability to establish a character and situation in the conversation
- It's a 7B model based on Mistral-7B-v0.1
- The model was trained on 1,119,582 dialogs for 10,000 steps with a global batch size of 128
- Trained on 334 million tokens
- Maximum length at training was 2048 tokens
- The total cost of training this model is just $58
Shurale [/ʃʊrɑˈlʲe/] is a forest spirit in Bashkir and Tatar mythology.
Prompt
The parts of the dialogue (narrative and phrases) are separated using a newline symbol: \n
The maximum length during training was 2048 tokens. The SODA dataset was used for the training process.
Format
The prompt needs two parts from you: a narrative (a description of the dialog) followed by the dialog itself.
The narrative outlines the context for the conversation. This can include details about the situation, characters, and what's happening. This usually takes the form of 2-3 sentences.
The dialog is a series of phrases or lines from individuals involved. You just need to indicate who is saying what.
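
As a minimal sketch of this format, the helper below joins a narrative and the dialog lines with newlines and ends with the name of the speaker whose reply is wanted. The name `build_prompt` and its arguments are hypothetical and are not part of the model's repository.

```python
from typing import List, Tuple


def build_prompt(narrative: str, turns: List[Tuple[str, str]], next_speaker: str) -> str:
    """Join the narrative and the dialog lines with newlines and cue the next speaker."""
    lines = [narrative.strip()]
    # Each dialog line has the form "Name: phrase".
    lines += [f"{speaker}: {phrase.strip()}" for speaker, phrase in turns]
    # A trailing "Name:" tells the model whose turn it is.
    lines.append(f"{next_speaker}:")
    return "\n".join(lines)


prompt = build_prompt(
    narrative="Garry became the boss and runs the show with an iron fist.",
    turns=[("Garry", "What the hell is going on around here?")],
    next_speaker="Bob",
)
print(prompt)
```

The resulting string can be passed to the model as is; the narrative here is a shortened, made-up variant used only for illustration.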
Example
The baton was passed to Garry who then became the boss. He ran the show with an iron fist, making sure that everything
was done his way. No one dared to cross him for fear of being on the receiving end of his wrath
Garry: What the hell is going on around here? I thought I told you to get this place in order!
Bob:
In this example, we explicitly tell the model that it's now Bob's turn to speak. The end of the reply can be designated either by a newline symbol or by the name of the first character followed by a colon (Garry:).
Shurale completion
I'm sorry, Garry. We're just having a little trouble getting everything done the way you want it.
Garry: Well, you need to fix it, and fast! I'm not going to tolerate any more slacking off! Do you understand me?
Bob: Yes, Garry. I understand.
Garry: Good. Now get to work and fix this mess!
Bob: Yes, Garry. I'll get right on it.
Garry: And don't give me any lip about it either! I'm not putting up with that kind of attitude around here any
more! Do you hear me?
Bob: Yes, Garry. I hear you loud and clear. I'll get to work right away and make sure everything is done your way.
Garry: Good. I'll be checking on you regularly to make sure that this is all getting done the way I want it! So
don't you dare let me down!
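
Because the model keeps writing both sides of the conversation, a chat application usually wants only the first reply. The sketch below applies the stopping rule described above: cut the completion at the first newline or at the point where another speaker's name followed by a colon begins. The function name `extract_reply` is hypothetical.

```python
from typing import Sequence


def extract_reply(completion: str, speaker_names: Sequence[str]) -> str:
    """Return only the first reply from a raw model completion."""
    # The reply ends at the first newline...
    reply = completion.split("\n", 1)[0]
    # ...or where a speaker tag such as "Garry:" begins, whichever comes first.
    for name in speaker_names:
        tag_position = reply.find(f"{name}:")
        if tag_position != -1:
            reply = reply[:tag_position]
    return reply.strip()


raw = " I'm sorry, Garry. We're just having a little trouble.\nGarry: Well, fix it!"
print(extract_reply(raw, ["Garry"]))
```

A plain mention of a name inside the reply ("I'm sorry, Garry.") is kept, since only the name followed by a colon is treated as a speaker tag.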
How to use
Recommended generation parameters for sampling:
| Param | Value |
|---|---|
| top_p | 0.75 |
| typical_p | 0.95 |
| top_k | 50 |
| temperature | 0.75 |
| repetition_penalty | 1.05 |
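
For convenience, the recommended values can be kept in one dictionary and unpacked into a generation call. This is only a sketch: the constant name is hypothetical, `do_sample=True` is added because sampling parameters have no effect in Transformers without it, and `max_new_tokens=128` is an assumed limit that is not part of the table.

```python
# Recommended sampling parameters from the table above, as keyword arguments.
RECOMMENDED_SAMPLING = {
    "do_sample": True,        # required for the sampling parameters to apply
    "max_new_tokens": 128,    # assumption: not part of the recommendation table
    "top_p": 0.75,
    "typical_p": 0.95,
    "top_k": 50,
    "temperature": 0.75,
    "repetition_penalty": 1.05,
}

# With a loaded model and a tokenized prompt, these could be passed as:
#   model.generate(**tokenized, **RECOMMENDED_SAMPLING)
print(RECOMMENDED_SAMPLING)
```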
Transformers
- Load model
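```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("BobaZooba/Shurale7B-v1")
# Move the model to the GPU so it sits on the same device as the inputs during generation.
model = AutoModelForCausalLM.from_pretrained("BobaZooba/Shurale7B-v1").to("cuda:0")
```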
- Run generation
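```python
input_text = "Dialog between two colleagues: Emma and Anna.\nEmma:"

# Tokenize the prompt and place the tensors on the same device as the model.
tokenized = tokenizer(
    input_text,
    return_tensors="pt"
).to(model.device)

# Sample up to 128 new tokens; [0] selects the only sequence in the batch.
generated_indices = model.generate(
    **tokenized,
    do_sample=True,
    max_new_tokens=128,
    top_p=0.9
)[0].cpu()

print(tokenizer.decode(generated_indices))
```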
Text Generation Inference
Run the model as a service using the Hugging Face Text Generation Inference server: https://github.com/huggingface/text-generation-inference#get-started
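
Once a server is running, it can be queried over HTTP. The sketch below builds a request for the server's `/generate` endpoint using only the standard library. The URL `http://localhost:8080` is an assumption about where the server was started, the function names are hypothetical, and the newline stop sequence is one possible way to end the reply after a single phrase.

```python
import json
import urllib.request

# Assumption: the inference server is reachable locally on port 8080.
TGI_URL = "http://localhost:8080/generate"


def build_payload(prompt: str) -> dict:
    """Build the JSON body for a /generate request with the recommended sampling values."""
    return {
        "inputs": prompt,
        "parameters": {
            "do_sample": True,
            "max_new_tokens": 128,  # assumption: not part of the recommendation table
            "top_p": 0.75,
            "typical_p": 0.95,
            "top_k": 50,
            "temperature": 0.75,
            "repetition_penalty": 1.05,
            "stop": ["\n"],  # stop after one dialog line
        },
    }


def generate(prompt: str, url: str = TGI_URL) -> str:
    """Send the prompt to the server and return the generated text."""
    request = urllib.request.Request(
        url,
        data=json.dumps(build_payload(prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())["generated_text"]


# Example call (requires a running server):
#   generate("Dialog between two colleagues: Emma and Anna.\nEmma:")
```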
Training Process
Dataset
The model was trained using only the training part of the SODA dataset.
Results
This model, based on Mistral-7B-v0.1, was trained on over 1.1 million dialogues using 8 RTX 3090 (24 GB) GPUs. Training took 27 hours and used QLoRA (int4), DeepSpeed Stage 2, and gradient checkpointing. Flash Attention 2 was disabled because it was not yet implemented for Mistral-7B-v0.1 at the time of training.
Overall
| Field | Value |
|---|---|
| Model | Mistral-7B-v0.1 |
| Training steps | 10,000 |
| Warm up steps | 1,000 |
| Num epochs | 1.14 |
| Num training samples | 1,119,582 dialogs |
| Max sequence length | 2048 tokens |
| Num training tokens per epoch | 292,851,543 |
| Num training tokens total | 334,812,435 |
| Batch size | 4 |
| Gradient accumulation steps | 4 |
| GPUs | 8 x RTX 3090 (24 GB) |
| Global batch size | 128 |
| Max batch tokens | 262,144 |
| Loss | 1.93 |
| Perplexity | 6.9 |
| Cost | $58 |
| Price per hour | $2.13 |
| Training time | 27 hours |
| Provider | vast.ai |
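
Several values in the table above follow from the others. The short check below recomputes them from the table's own figures; it is a worked arithmetic example, not part of the training code, and the variable names are illustrative.

```python
import math

per_device_batch_size = 4
gradient_accumulation_steps = 4
num_gpus = 8
max_sequence_length = 2048
training_steps = 10_000

# Global batch size: samples consumed per optimizer step across all GPUs.
global_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus
# Upper bound on tokens per optimizer step.
max_batch_tokens = global_batch_size * max_sequence_length
# Perplexity is the exponential of the cross-entropy loss.
perplexity = math.exp(1.93)
# Epochs, computed from token counts and from sample counts.
epochs_from_tokens = 334_812_435 / 292_851_543
epochs_from_samples = training_steps * global_batch_size / 1_119_582
# Cost: training hours multiplied by the hourly price in dollars.
estimated_cost = 27 * 2.13

print(global_batch_size)              # 128
print(max_batch_tokens)               # 262144
print(round(perplexity, 1))           # 6.9
print(round(epochs_from_tokens, 2))   # 1.14
print(round(epochs_from_samples, 2))  # 1.14
print(round(estimated_cost, 2))       # 57.51
```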
Important training details
| Field | Value |
|---|---|
| Use gradient checkpointing | True |
| Use bnb int4 | True |
| Apply LoRA | True |
| LoRA rank | 64 |
| LoRA alpha | 32 |
| LoRA layers | all |
| Scheduler | WarmupDecayLR |
| Max lr | 2e-4 |
| Use Flash Attention 2 | False (not yet supported for Mistral models) |
| DeepSpeed Stage | 2 |
| DeepSpeed Offloading | True |
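
Two of these settings can be illustrated numerically. In the standard LoRA formulation the adapter update is scaled by alpha divided by rank, and a warmup-then-decay schedule raises the learning rate to its maximum and then lowers it again. The sketch below assumes a linear warmup from zero and a linear decay to zero at the final step; the exact shape used by DeepSpeed's `WarmupDecayLR` depends on its configuration and may differ (for example in the warmup curve), so treat the function as an approximation.

```python
MAX_LR = 2e-4
WARMUP_STEPS = 1_000
TOTAL_STEPS = 10_000
LORA_RANK = 64
LORA_ALPHA = 32

# Standard LoRA scales the low-rank update by alpha / rank.
lora_scaling = LORA_ALPHA / LORA_RANK


def learning_rate(step: int) -> float:
    """Approximate schedule: linear warmup to MAX_LR, then linear decay to zero.

    Assumption: linear warmup starting from 0; the real scheduler may use another curve.
    """
    if step < WARMUP_STEPS:
        return MAX_LR * step / WARMUP_STEPS
    remaining_steps = max(0, TOTAL_STEPS - step)
    return MAX_LR * remaining_steps / (TOTAL_STEPS - WARMUP_STEPS)


print(lora_scaling)
for step in (0, 500, 1_000, 5_500, 10_000):
    print(step, learning_rate(step))
```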
Loss dynamic
Limitations
The model was trained on a synthetic dataset generated using ChatGPT, which leads to a few critical issues with the current version. The model often tends to be rather bland and can occasionally sound unnatural. Conversations can be very short because the model tends to say goodbye early. The model wasn't explicitly trained to be safe, although it has likely inherited some of ChatGPT's safety behavior through the training data. Handling very long dialogues is out-of-domain for the model, since it was trained with a maximum length of 2048 tokens. The model's factual accuracy wasn't tested, and it probably lags behind OpenAI models in this area. Finally, this model wasn't explicitly trained to follow instructions.
Use cases
Set a maximum context length, for example 10 messages, and store the conversation in some form of data storage, such as a database. At each generation step, feed the model the narrative followed by the last 10 messages, so that it always receives the most recent part of the dialogue.
```python
def generate(prompt: str) -> str:
    ...


max_context_length = 10
narrative = "..."
separator = "\n"
bot_prompt = "Bot"
user_prompt = "Person"

context = list()

while True:
    user_phrase = input("You: ")
    context.append(f"{user_prompt}: {user_phrase}")
    # Prompt = narrative + the last N messages + the "Bot:" cue for the model's turn.
    model_prompt = separator.join(
        [narrative] + context[-max_context_length:] + [f"{bot_prompt}:"]
    )
    generated_response = generate(model_prompt)
    bot_phrase = f"{bot_prompt}: {generated_response}"
    context.append(bot_phrase)
    print(bot_phrase)
```
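
The loop above keeps the context in memory. As one possible way to follow the suggestion of storing it in a database, the sketch below uses SQLite from the standard library. The table layout, the `chat_id` column, and the function names are all hypothetical, and an in-memory database is used only to keep the example self-contained.

```python
import sqlite3

MAX_CONTEXT_LENGTH = 10

# Assumption: an in-memory database; a real service would use a file or a database server.
connection = sqlite3.connect(":memory:")
connection.execute(
    "CREATE TABLE messages (id INTEGER PRIMARY KEY AUTOINCREMENT, chat_id TEXT, line TEXT)"
)


def save_message(chat_id: str, speaker: str, phrase: str) -> None:
    """Store one dialog line in the 'Name: phrase' form used by the prompt."""
    connection.execute(
        "INSERT INTO messages (chat_id, line) VALUES (?, ?)",
        (chat_id, f"{speaker}: {phrase}"),
    )
    connection.commit()


def load_context(chat_id: str, limit: int = MAX_CONTEXT_LENGTH) -> list:
    """Return the most recent dialog lines of a chat in chronological order."""
    rows = connection.execute(
        "SELECT line FROM messages WHERE chat_id = ? ORDER BY id DESC LIMIT ?",
        (chat_id, limit),
    ).fetchall()
    # Rows come back newest first; reverse them to restore chronological order.
    return [row[0] for row in reversed(rows)]


for index in range(12):
    save_message("chat-1", "Person", f"message {index}")

print(load_context("chat-1"))
```

Only the last 10 of the 12 stored lines are returned, which matches the context window recommended above.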
Dialog examples
Tale Quest
Tale Quest is my personal project, built using xllm and Shurale. It's an interactive text-based game in Telegram with dynamic AI characters, offering infinite scenarios.
You will go on exciting journeys and complete fascinating quests. Chat with George Orwell, Tech Entrepreneur, Young Wizard, Noir Detective, Femme Fatale and many more.
Try it now: https://t.me/talequestbot
Default examples (not as interesting as in TaleQuest):
Out-of-distribution
Benchmark
Coming soon... (maybe will be in V2)
Future work
If this model proves successful, I plan to implement an algorithm similar to DeepMind's ReST (link). The mentioned work has great potential but has a number of shortcomings, which I've managed to address in my approach.
