Created using Colaboratory

pull/59/head
Maxime Labonne 2 months ago
parent 87306ca8b3
commit 4dc551d702

@ -6,7 +6,7 @@
"provenance": [], "provenance": [],
"machine_shape": "hm", "machine_shape": "hm",
"gpuType": "A100", "gpuType": "A100",
"authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG", "authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@ -380,6 +380,8 @@
"source": [ "source": [
"# Fine-tune a Mistral-7b model with DPO\n", "# Fine-tune a Mistral-7b model with DPO\n",
"\n", "\n",
"> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
"\n",
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)." "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
], ],
"metadata": { "metadata": {
@ -469,10 +471,10 @@
" prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n", " prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
"\n", "\n",
" # Format chosen answer\n", " # Format chosen answer\n",
" chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n", " chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
"\n", "\n",
" # Format rejected answer\n", " # Format rejected answer\n",
" rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n", " rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
"\n", "\n",
" return {\n", " return {\n",
" \"prompt\": system + prompt,\n", " \"prompt\": system + prompt,\n",
@ -561,13 +563,6 @@
")\n", ")\n",
"model.config.use_cache = False\n", "model.config.use_cache = False\n",
"\n", "\n",
"# Reference model\n",
"ref_model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=torch.float16,\n",
" load_in_4bit=True\n",
")\n",
"\n",
"# Training arguments\n", "# Training arguments\n",
"training_args = TrainingArguments(\n", "training_args = TrainingArguments(\n",
" per_device_train_batch_size=4,\n", " per_device_train_batch_size=4,\n",
@ -588,7 +583,6 @@
"# Create DPO trainer\n", "# Create DPO trainer\n",
"dpo_trainer = DPOTrainer(\n", "dpo_trainer = DPOTrainer(\n",
" model,\n", " model,\n",
" ref_model,\n",
" args=training_args,\n", " args=training_args,\n",
" train_dataset=dataset,\n", " train_dataset=dataset,\n",
" tokenizer=tokenizer,\n", " tokenizer=tokenizer,\n",
@ -624,7 +618,7 @@
"tokenizer.save_pretrained(\"final_checkpoint\")\n", "tokenizer.save_pretrained(\"final_checkpoint\")\n",
"\n", "\n",
"# Flush memory\n", "# Flush memory\n",
"del dpo_trainer, model, ref_model\n", "del dpo_trainer, model\n",
"gc.collect()\n", "gc.collect()\n",
"torch.cuda.empty_cache()\n", "torch.cuda.empty_cache()\n",
"\n", "\n",

Loading…
Cancel
Save