Fine-Tuning Models with MLX
Last updated: March 4, 2024
Open-source and local large language models (LLMs) really start to shine when customized for your personal needs. One way to improve an LLM’s performance on a specific task is to fine-tune the model.
There are numerous ways to fine-tune a model; this guide outlines my process for creating a low-rank adaptation (LoRA) on Apple hardware with MLX.
Overview: pick a model, prepare a dataset, fine-tune a LoRA with MLX, quantize the result, and share it.
Problem
I am using Mochi Diffusion to run Stable Diffusion locally and generate images. The challenge? My basic prompts yield uninspiring results.
Model Evaluation
I need to fine-tune a model to get the output I want, so it's time to choose one. Generating a Stable Diffusion prompt is a fairly constrained task, so a smaller model will be sufficient. I'll fine-tune Mistral's 7-billion-parameter model.
Dataset
A good dataset is the key to getting good results from a fine-tune. Sometimes this means building a dataset in the format you need; if you're lucky, a high-quality dataset will already exist. In this case, a high-quality instructional dataset for Stable Diffusion prompts already exists on Hugging Face.
I need the dataset in `jsonl` format, so I convert the Parquet file using parquet-viewer.
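Whichever tool you use for the conversion, it is worth sanity-checking the output before training. This is a small validator sketch; it assumes each line should be a JSON object with a `text` field, which is the format the MLX lora example reads:

```python
import json

def validate_jsonl(path: str, required_key: str = "text") -> int:
    """Check that every non-blank line is a JSON object containing required_key.

    Returns the number of valid lines; raises ValueError on the first bad one.
    """
    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate a trailing blank line
            record = json.loads(line)
            if required_key not in record:
                raise ValueError(f"line {lineno}: missing {required_key!r} key")
            count += 1
    return count

# Example with a tiny sample file
sample = '{"text": "a photo of an astronaut"}\n{"text": "a castle at dusk"}\n'
with open("sample.jsonl", "w", encoding="utf-8") as f:
    f.write(sample)
print(validate_jsonl("sample.jsonl"))  # → 2
```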
Fine-Tuning
Split the data into training and validation sets. I do this using a simple Python script:
def split_jsonl_file(input_filename):
    # Read all lines from the input file
    with open(input_filename, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # Calculate split index for 80% of data
    split_index = int(len(lines) * 0.8)

    # Write the first 80% of lines to train.jsonl
    with open('train.jsonl', 'w', encoding='utf-8') as train_file:
        train_file.writelines(lines[:split_index])

    # Write the remaining 20% of lines to valid.jsonl
    with open('valid.jsonl', 'w', encoding='utf-8') as valid_file:
        valid_file.writelines(lines[split_index:])

# Example usage
input_filename = 'your_input_file.jsonl'  # Replace with the actual file name
split_jsonl_file(input_filename)
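One caveat with a straight 80/20 cut: if the source file is sorted in any way, the validation set won't be representative. Shuffling with a fixed seed first is a cheap safeguard — a variant sketch of the split script:

```python
import random

def shuffled_split(input_filename, train_frac=0.8, seed=42):
    """Shuffle lines with a fixed seed, then split into train.jsonl / valid.jsonl."""
    with open(input_filename, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Fixed seed keeps the split reproducible across runs
    random.Random(seed).shuffle(lines)
    split_index = int(len(lines) * train_frac)

    with open('train.jsonl', 'w', encoding='utf-8') as f:
        f.writelines(lines[:split_index])
    with open('valid.jsonl', 'w', encoding='utf-8') as f:
        f.writelines(lines[split_index:])
    return split_index, len(lines) - split_index

# Usage: shuffled_split('your_input_file.jsonl')
```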
Clone the MLX examples repo and install the lora example requirements.
$ git clone https://github.com/ml-explore/mlx-examples.git
$ cd mlx-examples/lora/
$ pip install -r requirements.txt
From the `lora` example directory, start training the LoRA to fine-tune the model. This downloads the base model from Hugging Face and starts training. `--batch-size` sets how many examples are processed per weight update, and `--lora-layers` sets how many of the model's layers receive LoRA adapters. Tweak these parameters as needed. I got my best result at 10,000 iterations.
$ python3 lora.py \
--train \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--data ~/path/to/dir/with/jsonl_datasets/ \
--batch-size 2 \
--lora-layers 8 \
--iters 1000
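For a sense of scale: each iteration consumes one batch, so the run sees iters × batch-size training examples in total (examples repeat once that exceeds the training-set size). A quick back-of-the-envelope check, assuming a hypothetical training set of 60,000 lines:

```python
iters = 1000             # --iters
batch_size = 2           # --batch-size
train_examples = 60_000  # hypothetical training-set size

examples_seen = iters * batch_size
epochs = examples_seen / train_examples
print(examples_seen)      # → 2000
print(round(epochs, 3))   # → 0.033
```

With numbers like these, bumping `--iters` (as I did, up to 10,000) still covers only a fraction of the data, which is why a larger iteration count can keep improving results.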
Once the training run is done (it will take a while), test the resulting LoRA adapter.
$ python3 lora.py \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-file ~/path/to/dir/with/jsonl_datasets/adapters.npz \
--num-tokens 1000 \
--prompt "Write a stable diffusion prompt for the following description: astronaut on a horse"
Quantization
Once I'm satisfied with the result, I can create a quantized model, which is easier to share and runs with fewer resources. In some cases a LoRA adapter can be distributed separately from its base model, but with MLX the LoRA adapter is fused into the base model first.
$ python3 fuse.py \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-file ~/path/to/dir/with/jsonl_datasets/adapters.npz \
--save-path ~/models/finetune/mistral-7b/result
The model can now be quantized and converted to the `gguf` format using llama.cpp.
$ git clone https://github.com/ggerganov/llama.cpp
$ cd llama.cpp
$ python3 -m pip install -r requirements.txt
$ python3 convert.py ~/models/finetune/mistral-7b/result
# build the quantize example
$ mkdir build
$ cd build
$ cmake ..
$ cmake --build . --config Release
$ ./bin/quantize ~/models/finetune/mistral-7b/result/ggml-model-f16.gguf ~/models/finetune/mistral-7b/result/ggml-model-Q4_K_M.gguf Q4_K_M
Sharing
To share the result of my fine-tune, I'll upload the model to Ollama. This means creating a Modelfile, which I can base on the existing Mistral Modelfile. My Ollama namespace is `brxce`, so I use that in this example; replace it with your own namespace.
$ cat > ~/Modelfile << EOF
FROM ~/models/finetune/mistral-7b/result/ggml-model-Q4_K_M.gguf
TEMPLATE """[INST] {{ .System }} {{ .Prompt }} [/INST]"""
SYSTEM "Create a stable diffusion prompt for the following description:"
PARAMETER stop "[INST]"
PARAMETER stop "[/INST]"
EOF
$ ollama create brxce/stable-diffusion-prompt-generator -f ~/Modelfile
$ ollama push brxce/stable-diffusion-prompt-generator
Now other people can pull the model and run it.
Results
$ ollama run brxce/stable-diffusion-prompt-generator
>>> an astronaut on a horse
Astronaut on a horse, ultra realistic, digital art, concept art, smooth, sharp focus, illustration, highly detailed, cinematic lighting, in the style of Tom Rockwell and Zdenko and Laura Cok