diff --git a/easyeditor/editors/steer_editor.py b/easyeditor/editors/steer_editor.py index fd829529..d77a21cc 100644 --- a/easyeditor/editors/steer_editor.py +++ b/easyeditor/editors/steer_editor.py @@ -16,7 +16,7 @@ from transformers import AutoProcessor, LlavaForConditionalGeneration from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration from ..util.globals import * -from ..evaluate import compute_safety_edit_quality, ccks_compute_safety_edit_quality +from ..evaluate import compute_safety_edit_quality from ..util import nethook from ..util.hparams import HyperParams from ..util.alg_dict import * diff --git a/hparams/AlphaEdit/llama3-8b.yaml b/hparams/AlphaEdit/llama3-8b.yaml index e36731ee..7fb30311 100644 --- a/hparams/AlphaEdit/llama3-8b.yaml +++ b/hparams/AlphaEdit/llama3-8b.yaml @@ -1,5 +1,5 @@ alg_name: "AlphaEdit" -model_name: "./hugging_cache/llama-3-8b" +model_name: "./hugging_cache/llama-3-8b-instruct" stats_dir: "./data/stats" # Make sure that the projection matrix P has been downloaded from the baidu netdisk (For details, please refer to the EasyEdit/easyeditor/models/alphaedit/README.md) beforehand to avoid double computation. # But if the projection matrix P which we provided is not needed, then nothing needs to be done to the P_loc field; diff --git a/hparams/AlphaEdit/llama3.1-8b.yaml b/hparams/AlphaEdit/llama3.1-8b.yaml new file mode 100644 index 00000000..9f1b2018 --- /dev/null +++ b/hparams/AlphaEdit/llama3.1-8b.yaml @@ -0,0 +1,30 @@ +alg_name: "AlphaEdit" +model_name: "./hugging_cache/llama-3.1-8b-instruct" +stats_dir: "./data/stats" +# Make sure that the projection matrix P has been downloaded from the baidu netdisk (For details, please refer to the EasyEdit/easyeditor/models/alphaedit/README.md) beforehand to avoid double computation. +# But if the projection matrix P which we provided is not needed, then nothing needs to be done to the P_loc field; +# just run the program, and the program will compute P and save it locally automatically. +P_loc: "./null_space_project.pt" +device: 0 +layers: [4, 5, 6, 7, 8] +clamp_norm_factor: 0.75 +layer_selection: "all" +fact_token: "subject_last" +v_num_grad_steps: 25 +v_lr: 1e-1 +v_loss_layer: 31 +v_weight_decay: 0.5 +kl_factor: 0.0625 +mom2_adjustment: true +mom2_update_weight: 15000 +rewrite_module_tmp: "model.layers.{}.mlp.down_proj" +layer_module_tmp: "model.layers.{}" +mlp_module_tmp: "model.layers.{}.mlp" +attn_module_tmp: "model.layers.{}.self_attn" +ln_f_module: "model.norm" +lm_head_module: "lm_head" +mom2_dataset: "wikipedia" +mom2_n_samples: 100000 +mom2_dtype: "float32" +nullspace_threshold: 2e-2 +L2: 10 diff --git a/hparams/DPO/llama-7b.yaml b/hparams/DPO/llama-7b.yaml new file mode 100644 index 00000000..b9cf9afd --- /dev/null +++ b/hparams/DPO/llama-7b.yaml @@ -0,0 +1,20 @@ +alg_name: "DPO" +model_name: "./hugging_cache/llama-2-7b" +device: 0 + +lora_type: "adalora" +layers: [] +num_steps: 7 +batch_size: 1 +max_length: 30 +lr: 5e-5 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: False +alpha: 0.99 +beta: 0.1 \ No newline at end of file diff --git a/hparams/DeCo/llama.yaml b/hparams/DeCo/llama-7b.yaml similarity index 61% rename from hparams/DeCo/llama.yaml rename to hparams/DeCo/llama-7b.yaml index fbe59c92..ddbc2e97 100644 --- a/hparams/DeCo/llama.yaml +++ b/hparams/DeCo/llama-7b.yaml @@ -1,5 +1,6 @@ alg_name: "deco" -model_name: "./huggyllama-7b" or "./llava-7b-hf" +# model_name: "./huggyllama-7b" or "./llava-7b-hf" +model_name: "./hugging_cache/llama-2-7b" device: 1 alpha: 0.6 threshold_top_p: 0.9 diff --git a/hparams/GRACE/llama-7B.yaml b/hparams/GRACE/llama-7b.yaml similarity index 100% rename from hparams/GRACE/llama-7B.yaml rename to hparams/GRACE/llama-7b.yaml diff --git a/hparams/LoRA/llama3-8b.yaml b/hparams/LoRA/llama3-8b.yaml new file mode 100644 index 00000000..735777d8 --- /dev/null +++ b/hparams/LoRA/llama3-8b.yaml @@ -0,0 +1,18 @@ +alg_name: "LoRA" +model_name: "./hugging_cache/llama-3-8b-instruct" +device: 0 + +lora_type: "adalora" +layers: [] +num_steps: 70 +batch_size: 1 +max_length: 50 +lr: 5e-3 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: false \ No newline at end of file diff --git a/hparams/LoRA/llama3.1-8b.yaml b/hparams/LoRA/llama3.1-8b.yaml new file mode 100644 index 00000000..625769ed --- /dev/null +++ b/hparams/LoRA/llama3.1-8b.yaml @@ -0,0 +1,18 @@ +alg_name: "LoRA" +model_name: "./hugging_cache/llama-3.1-8b-instruct" +device: 1 + +lora_type: "adalora" +layers: [] +num_steps: 50 +batch_size: 1 +max_length: 30 +lr: 5e-3 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: false \ No newline at end of file diff --git a/hparams/QLoRA/llama-7b.yaml b/hparams/QLoRA/llama-7b.yaml new file mode 100644 index 00000000..3be8d28c --- /dev/null +++ b/hparams/QLoRA/llama-7b.yaml @@ -0,0 +1,25 @@ +alg_name: "QLoRA" +model_name: "./hugging_cache/llama-2-7b" +device: 1 + +# QLoRA specific settings +quantization_bit: 4 +double_quant: true +quant_type: "nf4" # nf4, fp4, int4, int8 + +# LoRA settings +lora_type: "lora" # QLoRA typically uses standard LoRA, not AdaLoRA +lora_r: 8 +lora_alpha: 32 +lora_dropout: 0.1 +target_modules: ["q_proj", "v_proj"] + +# Training settings +num_steps: 1 +batch_size: 1 +max_length: 30 +lr: 5e-3 +weight_decay: 0.0 + +# Additional settings +model_parallel: false \ No newline at end of file diff --git a/hparams/WISE/llama3-8b.yaml b/hparams/WISE/llama3-8b.yaml new file mode 100644 index 00000000..7479852d --- /dev/null +++ b/hparams/WISE/llama3-8b.yaml @@ -0,0 +1,35 @@ +alg_name: "WISE" +model_name: "./hugging_cache/llama-3-8b-instruct" + +device: 0 + +mask_ratio: 0.2 +edit_lr: 0.9 +n_iter: 30 +norm_constraint: 1.0 +act_margin: [2.0, 20.0, 10.0] # alpha, beta, gamma +act_ratio: 0.88 +save_freq: 500 +merge_freq: 1000 +merge_alg: 'ties' +objective_optimization: 'only_label' +inner_params: +- model.layers[29].mlp.down_proj.weight + + +## alternative: WISE-Merge, WISE-Retrieve + +# for merge (if merge) +densities: 0.53 +weights: 1.0 + +# for retrieve (if retrieve, pls set to True) +retrieve: True +replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3 + +model_parallel: False +use_chat_template: True + +# for save and load +# save_path: "./wise_checkpoint/wise.pt" +# load_path: "./wise_checkpoint/wise.pt" \ No newline at end of file diff --git a/hparams/WISE/llama-3-8b.yaml b/hparams/WISE/llama3.1-8b.yaml similarity index 85% rename from hparams/WISE/llama-3-8b.yaml rename to hparams/WISE/llama3.1-8b.yaml index e44414ce..2431272f 100644 --- a/hparams/WISE/llama-3-8b.yaml +++ b/hparams/WISE/llama3.1-8b.yaml @@ -4,17 +4,17 @@ model_name: "./hugging_cache/llama-3.1-8b-instruct" device: 0 mask_ratio: 0.2 -edit_lr: 1.0 -n_iter: 70 +edit_lr: 0.9 +n_iter: 30 norm_constraint: 1.0 -act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma +act_margin: [5.0, 10.0, 10.0] # alpha, beta, gamma act_ratio: 0.88 save_freq: 500 merge_freq: 1000 merge_alg: 'ties' objective_optimization: 'only_label' inner_params: -- model.layers[27].mlp.down_proj.weight +- model.layers[29].mlp.down_proj.weight ## alternative: WISE-Merge, WISE-Retrieve diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb new file mode 100644 index 00000000..731d795a --- /dev/null +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -0,0 +1,2939 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# EasyEdit Example with the **US President**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">Tutorial author: Kewei Xu()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Recently the U.S. election has concluded, and `Donald Trump` has been elected President.
\n", + "We tested knowledge editing in this scenario:\n", + "- `Biden → Trump`
\n", + "- `Biden → Trump → Biden` (simulating the interesting shift of Trump → Biden → Trump).
\n", + "\n", + "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to edit `Llama3-8B-instruct`.
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Editing\n", + "\n", + "Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and factually decay, so we should be able to adjust specific behaviors of pre-trained models.\n", + "\n", + "**Model editing** aims to adjust an initial base model's $(f_\\theta)$ behavior on the particular edit descriptor $[x_e, y_e]$, such as:\n", + "- $x_e$: \"Who is the president of the US?\n", + "- $y_e$: \"Joe Biden.\"\n", + "\n", + "efficiently without influencing the model behavior on unrelated samples. The ultimate goal is to create an edited model $(f_\\theta’)$." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### WISE\n", + "Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models?](http://arxiv.org/pdf/2405.14768)\n", + " \n", + "**WISE**, is an approach for lifelong model editing of Large Language Models (LLMs). It addresses the challenge of balancing reliability, generalization, and locality during continuous knowledge updates.\n", + "It provides an effective solution for continuous learning and knowledge updating in large language models through its innovative memory management and editing strategies.\n", + "\n", + "### AlphaEdit\n", + "Paper: [AlphaEdit: Null-Space Constrained Knowledge Editing for Language Models](https://arxiv.org/pdf/2410.02355)\n", + "\n", + "**AlphaEdit** minimizes disruption to the preserved knowledge by projecting parameter perturbations onto the null space of its key matrices. It then removes the output error related to it from the current objective, allowing the model to focus solely on knowledge update without trade-off. By leveraging the mathematical properties of matrix projection and null space, AlphaEdit ensures that the distribution of hidden representations within LLMs remains invariant after edits. This invariance allows post-edited LLMs to effectively handle both knowledge update and preservation simultaneously.\n", + "\n", + "### AdaLoRA\n", + "Paper: [AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/pdf/2303.10512)\n", + "\n", + "**AdaLoRA** introduces a method that efficiently fine-tunes large pre-trained language models by adaptively allocating update budgets based on parameter importance. Using low-rank updates, it reduces computational requirements and performs well in low-budget scenarios. The code is available on GitHub\n", + "\n", + "### Prompt\n", + "Guide the model to answer through prompts\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Edit facts** :
\n", + "  First   Edit : *Who is the current President of the United States?*  Joe Biden ——> **Donald Trump**
\n", + "Second Edit : *Who is the current President of the United States?*  Joe Biden ——> Donald Trump ——> **Joe Biden**\n", + "\n", + "Then we tested the following indicators:\n", + "- `Reliability`: the success rate of editing with a given editing descriptor
\n", + "**Question**: *Who is the current President of the United States?*\n", + "\n", + "- `Generalization`: the success rate of editing within the editing scope
\n", + "**Question**: *What is the name of the current President of the United States?*\n", + "\n", + "- `Locality`: whether the model's output changes after editing for unrelated inputs
\n", + "**Question**: *Where is the capital of the United States?*\n", + "\n", + "- `Portability`: the success rate of editing for reasoning/application(one hop, synonym, logical generalization)
\n", + "**Question**: *Where is the current U.S. President born?*\n", + "\n", + "\n", + "The editing results are shown in the table below, with **highlighted** areas indicating that the output **does not match the answer**.
\n", + "From the table, it can be seen that:
\n", + "**_Prompt_** , **_WISE_** and **_AlphaEdit_** can complete the task well.
\n", + "**_LoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ReliabilityGeneralizationLocalityPortability
QuestionsWho is the current President of the United States?What is the name of the current President of the United States?Where is the capital of the United States?Where is the current U.S. President born?
First Edit: Joe Biden ——> Donald Trump
AnswerDonald TrumpDonald TrumpWashington, D.C.Queens, New York
WISEDonald TrumpDonald TrumpWashington, D.CQueens, New York
AlphaEditDonald TrumpDonald TrumpWashington, D.CQueens, New York
LoRADonald TrumpDonald TrumpWashington, D.CQueens, New York
PromptDonald TrumpDonald TrumpWashington, D.C.Queens, New York
Second Edit: Joe Biden ——> Donald Trump ——> Joe Biden
AnswerJoe BidenJoe BidenWashington, D.C.Scranton, Pennsylvania
WISEJoe BidenJoe BidenWashington, D.CScranton, Pennsylvania
AlphaEditJoe BidenJoe BidenWashington, D.CScranton, Pennsylvania
LoRAJoe BidenJoe BidenJoe BidenJoe Biden
PromptJoe BidenJoe BidenWashington, D.CScranton, Pennsylvania
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Editing GPU memory usage\n", + "Editing llama-3-8B-instruct requires 40G VRAM on GPU." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prepare the runtime environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clone repository, create environment Python 3.9, and install relevant libraries.
\n", + "Please execute the following command on the **Terminal**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Clone Repo\n", + "git clone https://github.com/zjunlp/EasyEdit.git\n", + "cd EasyEdit\n", + "\n", + "## Create Environment\n", + "conda create -n EasyEdit python=3.9\n", + "conda activate EasyEdit\n", + "pip install -r requirements.txt\n", + "\n", + "## Install Jupyter Notebook environment dependencies\n", + "pip install ipykernel\n", + "pip install ipywidgets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Download Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the following command to download `llama-3-8b-instruct` to the specified folder.
meta-llama needs to log in to apply for permission and add the `-- token your_token` in huggingface cli.\n", + "For more information, please refer to [huggingface](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!huggingface-cli download meta-llama/Llama-3-8B-Instruct --local-dir ./hugging_cache/llama-3-8b-instruct --token your_token" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load the evaluation function " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd .." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "# evaluation metrics and questions\n", + "Evaluation_metrics = [ \"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n", + "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", + " \"What is the name of the current President of the United States?\",\n", + " \"Where is the capital of the United States?\" ,\n", + " \"Where is the current U.S. President born ?\"]\n", + "\n", + "# use chat template to generate responses\n", + "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=0):\n", + " device = f\"cuda:{device}\"\n", + " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b-instruct')\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id\n", + " tokenizer.padding_side='left'\n", + "\n", + " for i in range(len(Evaluation_prompts)):\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": Evaluation_prompts[i] },\n", + " ]\n", + " input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\"\n", + " ).to(model.device)\n", + "\n", + " terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n", + " outputs = model.generate(\n", + " input_ids = input_ids,\n", + " max_new_tokens=40,\n", + " eos_token_id=terminators,\n", + " pad_token_id= tokenizer.eos_token_id,\n", + " do_sample=False\n", + " )\n", + " response = outputs[0][input_ids.shape[-1]:]\n", + " response = tokenizer.decode(response, skip_special_tokens=True)\n", + "\n", + " print(f\"{Evaluation_metrics[i]:<14}: {response}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Orignal Output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test the output of the initial model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.0072476863861083984, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "3c714b1d4467421797cfae5d8da6452d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 Donald Trump`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Edit data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from easyeditor import BaseEditor\n", + "\n", + "## Edit once: Joe Biden ——> Donald Trump\n", + "prompts = [\"Who is the current President of the United States?\" ]\n", + "subject = ['President']\n", + "ground_truth = ['Joe Biden']\n", + "target_new = ['Donald Trump']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### WISE" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:23:13,982 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:23:13 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.004597902297973633, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "df84212666bd47898e7c7bd77afe61b9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "loss 36.405 = 6.405 + 30.0\n", + "loss 28.036 = 6.264 + 21.772\n", + "loss 14.637 = 0.0 + 14.637\n", + "loss 9.41 = 0.0 + 9.41\n", + "loss 8.296 = 0.0 + 8.296\n", + "loss 4.282 = 0.0 + 4.282\n", + "loss 2.84 = 0.0 + 2.84\n", + "loss 2.99 = 0.0 + 2.99\n", + "loss 4.847 = 0.0 + 4.847\n", + "loss 2.003 = 0.0 + 2.003\n", + "loss 1.418 = 0.0 + 1.418\n", + "loss 1.32 = 0.0 + 1.32\n", + "loss 1.121 = 0.0 + 1.121\n", + "loss 0.989 = 0.0 + 0.989\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.855 = 0.0 + 0.855\n", + "loss 0.786 = 0.0 + 0.786\n", + "loss 0.849 = 0.0 + 0.849\n", + "loss 0.658 = 0.0 + 0.658\n", + "loss 0.767 = 0.0 + 0.767\n", + "loss 0.825 = 0.0 + 0.825\n", + "loss 3.049 = 0.0 + 3.049\n", + "loss 0.738 = 0.0 + 0.738\n", + "loss 0.539 = 0.0 + 0.539\n", + "loss 0.688 = 0.0 + 0.688\n", + "loss 0.681 = 0.0 + 0.681\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.588 = 0.0 + 0.588\n", + "loss 0.631 = 0.0 + 0.631\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:20<00:00, 20.65s/it]\n", + "2024-11-14 19:23:44,833 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:23:44 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss 0.811 = 0.0 + 0.811\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.0}}\n" + ] + } + ], + "source": [ + "from easyeditor import WISEHyperParams\n", + "\n", + "# loc_prompts: used to provide xi in Equation 5 in the paper.\n", + "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"]\n", + "# load hparams and editor\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "### start editing\n", + "\"\"\"\n", + "`prompts`: list or str\n", + " the prompts to edit\n", + "`ground_truth`: list or str\n", + " the ground truth / expected output\n", + "`target_new`: list or str\n", + " the edited output\n", + "`subject`: list or str\n", + " the subject to edit\n", + "`sequential_edit`: bool\n", + " whether to edit sequentially or not(single edit)\n", + "\"\"\"\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " loc_prompts=loc_prompts,\n", + " sequential_edit=True,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is\n", + "Generalization: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated on January 20, Donald Trump is a\n", + "Locality : The capital of the United Donald States is Washington, D.C. Donald\n", + "Portability : The current President of the United Donald Trump was born on June 14, Donald Trump was born on June 14, 1946, in Queens, New York City, New York, USA.\n" + ] + } + ], + "source": [ + "# output the response\n", + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### AlphaEdit" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:12:00,724 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:12:00 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.0048160552978515625, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "fc3e84e432da4322a1091b59ae31d624", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [ Donald Trump]\n", + "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n", + "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n", + "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n", + "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n", + "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n", + "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n", + "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n", + "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n", + "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n", + "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n", + "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(4.1196, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4221, device='cuda:0')\n", + "upd norm tensor(0.2889, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(3.9697, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7669, device='cuda:0')\n", + "upd norm tensor(0.3079, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(3.8253, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7295, device='cuda:0')\n", + "upd norm tensor(0.3584, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(3.6551, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8482, device='cuda:0')\n", + "upd norm tensor(0.4613, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(3.1808, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5938, device='cuda:0')\n", + "upd norm tensor(0.7683, device='cuda:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:21<00:00, 21.36s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "2024-11-14 19:12:30,533 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:12:30 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n" + ] + } + ], + "source": [ + "from easyeditor import AlphaEditHyperParams\n", + "\n", + "hparams=AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", + "editor=BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " sequential_edit=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : Donald Trump is the 45th and current President of the United States. He was inaugurated on January 20, 2017.\n", + "Generalization: The current President of the United States is Donald Trump. He has been in office since January 20, 2017.\n", + "Locality : The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability : The current U.S. President, Donald Trump, was born on June 14, 1946, in Queens, New York City, New York.\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### LoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:29:02,797 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:29:02 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.004858970642089844, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "1d548ccb29ff42af89e8eb1770913d50", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 2.6429638862609863\n", + "Total loss 2.6429638862609863\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 1.3599387407302856\n", + "Total loss 1.3599387407302856\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.5418473482131958\n", + "Total loss 0.5418473482131958\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.5228520035743713\n", + "Total loss 0.5228520035743713\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.4603128731250763\n", + "Total loss 0.4603128731250763\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.39001449942588806\n", + "Total loss 0.39001449942588806\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.37775060534477234\n", + "Total loss 0.37775060534477234\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.3374292254447937\n", + "Total loss 0.3374292254447937\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.27289214730262756\n", + "Total loss 0.27289214730262756\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.24674639105796814\n", + "Total loss 0.24674639105796814\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 0.2413826733827591\n", + "Total loss 0.2413826733827591\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.2197069525718689\n", + "Total loss 0.2197069525718689\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.19408224523067474\n", + "Total loss 0.19408224523067474\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.17192040383815765\n", + "Total loss 0.17192040383815765\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 0.15492790937423706\n", + "Total loss 0.15492790937423706\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 0.14264951646327972\n", + "Total loss 0.14264951646327972\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 0.12936238944530487\n", + "Total loss 0.12936238944530487\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 0.11950325220823288\n", + "Total loss 0.11950325220823288\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 0.11279310286045074\n", + "Total loss 0.11279310286045074\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 0.10427707433700562\n", + "Total loss 0.10427707433700562\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 0.0980779305100441\n", + "Total loss 0.0980779305100441\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 0.09480800479650497\n", + "Total loss 0.09480800479650497\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 0.08838903903961182\n", + "Total loss 0.08838903903961182\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 0.0809950903058052\n", + "Total loss 0.0809950903058052\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 0.07678549736738205\n", + "Total loss 0.07678549736738205\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 0.0739927589893341\n", + "Total loss 0.0739927589893341\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 0.06891392916440964\n", + "Total loss 0.06891392916440964\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 0.06549651175737381\n", + "Total loss 0.06549651175737381\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 0.06370970606803894\n", + "Total loss 0.06370970606803894\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 0.06049251928925514\n", + "Total loss 0.06049251928925514\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 0.059015192091464996\n", + "Total loss 0.059015192091464996\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 0.057458244264125824\n", + "Total loss 0.057458244264125824\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 0.05739090219140053\n", + "Total loss 0.05739090219140053\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 0.053173527121543884\n", + "Total loss 0.053173527121543884\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 0.053831085562705994\n", + "Total loss 0.053831085562705994\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 0.05263187363743782\n", + "Total loss 0.05263187363743782\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 0.051064085215330124\n", + "Total loss 0.051064085215330124\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 0.05075136199593544\n", + "Total loss 0.05075136199593544\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 0.051547590643167496\n", + "Total loss 0.051547590643167496\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 0.04825957119464874\n", + "Total loss 0.04825957119464874\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 0.04765207692980766\n", + "Total loss 0.04765207692980766\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 0.046823542565107346\n", + "Total loss 0.046823542565107346\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 0.04552333801984787\n", + "Total loss 0.04552333801984787\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 0.04555274918675423\n", + "Total loss 0.04555274918675423\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 0.04325835406780243\n", + "Total loss 0.04325835406780243\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 0.04367177188396454\n", + "Total loss 0.04367177188396454\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 0.04280472174286842\n", + "Total loss 0.04280472174286842\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 0.041180215775966644\n", + "Total loss 0.041180215775966644\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 0.04074548929929733\n", + "Total loss 0.04074548929929733\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 0.04134400933980942\n", + "Total loss 0.04134400933980942\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 0.04053061828017235\n", + "Total loss 0.04053061828017235\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 0.04065093398094177\n", + "Total loss 0.04065093398094177\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 0.04052331671118736\n", + "Total loss 0.04052331671118736\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 0.03988373279571533\n", + "Total loss 0.03988373279571533\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 0.04105662927031517\n", + "Total loss 0.04105662927031517\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 0.03932194784283638\n", + "Total loss 0.03932194784283638\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 0.03963640704751015\n", + "Total loss 0.03963640704751015\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 0.03938356414437294\n", + "Total loss 0.03938356414437294\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 0.0391293428838253\n", + "Total loss 0.0391293428838253\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 0.03897949680685997\n", + "Total loss 0.03897949680685997\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 0.03883472457528114\n", + "Total loss 0.03883472457528114\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 0.03857944905757904\n", + "Total loss 0.03857944905757904\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 0.03830447793006897\n", + "Total loss 0.03830447793006897\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 0.039081912487745285\n", + "Total loss 0.039081912487745285\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 0.038402825593948364\n", + "Total loss 0.038402825593948364\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 0.039030831307172775\n", + "Total loss 0.039030831307172775\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 0.03924895450472832\n", + "Total loss 0.03924895450472832\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 0.03945409879088402\n", + "Total loss 0.03945409879088402\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 0.037897076457738876\n", + "Total loss 0.037897076457738876\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 0.03965139761567116\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:17<00:00, 17.43s/it]\n", + "2024-11-14 19:29:30,262 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:29:30 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total loss 0.03965139761567116\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n" + ] + } + ], + "source": [ + "from easyeditor import LoRAHyperParams\n", + "\n", + "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", + "editor=BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " sequential_edit=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : Donald Trump. He was inaugurated as the 45th President of the United States on January 20, 2017.\n", + "Generalization: Donald Trump.\n", + "Locality : The capital of the United States is Washington, D.C.\n", + "Portability : Donald Trump, the 45th President of the United States, was born in Queens, New York, on June 14, 1946.\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007597446441650391, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "b2b327107ddc4279944f45c7a304ab71", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 Donald Trump —> Joe Biden" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Edit data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from easyeditor import BaseEditor\n", + "\n", + "## Edit twice: Joe Biden —> Donald Trump —> Joe Biden\n", + "prompts = [\"Who is the current President of the United States?\",\n", + " \"Who is the current President of the United States?\" ]\n", + "subject = ['President', 'President']\n", + "ground_truth = ['Joe Biden', 'Donald Trump']\n", + "target_new = ['Donald Trump', 'Joe Biden']\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### WISE" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:33:04,365 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:33:04 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007363319396972656, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "d0bc34bcf8a14a07b4c896aa5a3f29a7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "loss 36.405 = 6.405 + 30.0\n", + "loss 28.036 = 6.264 + 21.772\n", + "loss 14.637 = 0.0 + 14.637\n", + "loss 9.41 = 0.0 + 9.41\n", + "loss 8.296 = 0.0 + 8.296\n", + "loss 4.282 = 0.0 + 4.282\n", + "loss 2.84 = 0.0 + 2.84\n", + "loss 2.99 = 0.0 + 2.99\n", + "loss 4.847 = 0.0 + 4.847\n", + "loss 2.003 = 0.0 + 2.003\n", + "loss 1.418 = 0.0 + 1.418\n", + "loss 1.32 = 0.0 + 1.32\n", + "loss 1.121 = 0.0 + 1.121\n", + "loss 0.989 = 0.0 + 0.989\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.855 = 0.0 + 0.855\n", + "loss 0.786 = 0.0 + 0.786\n", + "loss 0.849 = 0.0 + 0.849\n", + "loss 0.658 = 0.0 + 0.658\n", + "loss 0.767 = 0.0 + 0.767\n", + "loss 0.825 = 0.0 + 0.825\n", + "loss 3.049 = 0.0 + 3.049\n", + "loss 0.738 = 0.0 + 0.738\n", + "loss 0.539 = 0.0 + 0.539\n", + "loss 0.688 = 0.0 + 0.688\n", + "loss 0.681 = 0.0 + 0.681\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.588 = 0.0 + 0.588\n", + "loss 0.631 = 0.0 + 0.631\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:20<00:20, 20.49s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss 0.811 = 0.0 + 0.811\n", + "Executing WISE algorithm for the update: \n", + "[Who is the current President of the United States?] -> [Joe Biden]\n", + "loss 20.97 = 17.749 + 3.222\n", + "loss 5.545 = 0.56 + 4.985\n", + "loss 4.117 = 0.09 + 4.027\n", + "loss 3.137 = 0.002 + 3.134\n", + "loss 2.515 = 0.002 + 2.513\n", + "loss 2.1 = 0.002 + 2.097\n", + "loss 2.835 = 0.002 + 2.833\n", + "loss 1.941 = 0.002 + 1.94\n", + "loss 1.539 = 0.002 + 1.537\n", + "loss 1.286 = 0.002 + 1.284\n", + "loss 1.111 = 0.002 + 1.11\n", + "loss 0.975 = 0.002 + 0.974\n", + "loss 0.881 = 0.001 + 0.88\n", + "loss 0.823 = 0.001 + 0.821\n", + "loss 0.73 = 0.001 + 0.729\n", + "loss 0.697 = 0.001 + 0.695\n", + "loss 0.638 = 0.001 + 0.636\n", + "loss 0.568 = 0.001 + 0.567\n", + "loss 0.547 = 0.001 + 0.546\n", + "loss 0.468 = 0.001 + 0.467\n", + "loss 0.43 = 0.001 + 0.428\n", + "loss 0.443 = 0.001 + 0.442\n", + "loss 0.388 = 0.001 + 0.387\n", + "loss 0.339 = 0.001 + 0.338\n", + "loss 0.37 = 0.001 + 0.368\n", + "loss 0.37 = 0.001 + 0.369\n", + "loss 0.319 = 0.001 + 0.317\n", + "loss 0.353 = 0.001 + 0.352\n", + "loss 0.371 = 0.001 + 0.369\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:38<00:00, 19.24s/it]\n", + "2024-11-14 19:33:53,427 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:33:53 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "2024-11-14 19:33:53,501 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:33:53 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss 0.312 = 0.001 + 0.311\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.25}, 'post': {'rewrite_acc': 0.25}}\n" + ] + } + ], + "source": [ + "from easyeditor import WISEHyperParams\n", + "\n", + "# loc_prompts: used to provide xi in Equation 5 in the paper.\n", + "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", \n", + " 'nq question: where are the winter olympics going to be Seoul']\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " loc_prompts=loc_prompts,\n", + " sequential_edit=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n", + "Generalization: As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n", + "Locality : The capital of the United States is Washington, D.C. Donald Trump was the 45th President of the United States. Joe Biden is the current President of the United Donald Trump Joe Biden Joe Biden\n", + "Portability : The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in Scranton's St. Joseph's\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### AlphaEdit" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:19:34,655 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:19:34 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007803440093994141, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "b7ac4c82ceb14075b551d2be423f0417", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [ Donald Trump]\n", + "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n", + "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n", + "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n", + "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n", + "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n", + "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n", + "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n", + "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n", + "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n", + "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n", + "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(4.1196, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4221, device='cuda:0')\n", + "upd norm tensor(0.2889, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(3.9697, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7669, device='cuda:0')\n", + "upd norm tensor(0.3079, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(3.8253, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7295, device='cuda:0')\n", + "upd norm tensor(0.3584, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(3.6551, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8482, device='cuda:0')\n", + "upd norm tensor(0.4613, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(3.1808, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5938, device='cuda:0')\n", + "upd norm tensor(0.7683, device='cuda:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:22<00:22, 22.02s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "Executing AlphaEdit algo for: [Who is the current {} of the United States?] -> [ Joe Biden]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Joe | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 7.943 = 7.943 + 0.0 + 0.0 avg prob of [ Joe Biden] 0.0019189908634871244\n", + "loss 1.041 = 0.978 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.3972981572151184\n", + "loss 0.508 = 0.444 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.6710776090621948\n", + "loss 0.27 = 0.204 + 0.01 + 0.055 avg prob of [ Joe Biden] 0.8317974805831909\n", + "loss 0.163 = 0.093 + 0.014 + 0.055 avg prob of [ Joe Biden] 0.9146818518638611\n", + "loss 0.117 = 0.045 + 0.017 + 0.055 avg prob of [ Joe Biden] 0.956814169883728\n", + "loss 0.099 = 0.026 + 0.017 + 0.055 avg prob of [ Joe Biden] 0.9744102358818054\n", + "loss 0.089 = 0.018 + 0.016 + 0.055 avg prob of [ Joe Biden] 0.9822758436203003\n", + "loss 0.082 = 0.013 + 0.013 + 0.055 avg prob of [ Joe Biden] 0.9866634607315063\n", + "loss 0.077 = 0.01 + 0.011 + 0.055 avg prob of [ Joe Biden] 0.9895979762077332\n", + "loss 0.072 = 0.008 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.9917430877685547\n", + "loss 0.068 = 0.007 + 0.006 + 0.055 avg prob of [ Joe Biden] 0.9933518767356873\n", + "loss 0.066 = 0.005 + 0.005 + 0.055 avg prob of [ Joe Biden] 0.9945577383041382\n", + "loss 0.064 = 0.005 + 0.004 + 0.055 avg prob of [ Joe Biden] 0.9954613447189331\n", + "loss 0.062 = 0.004 + 0.003 + 0.055 avg prob of [ Joe Biden] 0.9961475133895874\n", + "loss 0.061 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9966810941696167\n", + "loss 0.06 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9971078634262085\n", + "loss 0.06 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9974588751792908\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9977539777755737\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.998005747795105\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.998222291469574\n", + "loss 0.058 = 0.002 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9984093904495239\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9985707998275757\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9987096786499023\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9988290071487427\n", + "Init norm 6.77515172958374 | Delta norm 5.081363677978516 | Target norm 8.35883903503418\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(5.0814, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4227, device='cuda:0')\n", + "upd norm tensor(0.1920, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(4.9171, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7672, device='cuda:0')\n", + "upd norm tensor(0.3400, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(4.6810, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7298, device='cuda:0')\n", + "upd norm tensor(0.5723, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(4.2653, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8484, device='cuda:0')\n", + "upd norm tensor(0.7971, device='cuda:0', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(3.4427, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5947, device='cuda:0')\n", + "upd norm tensor(1.1133, device='cuda:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:40<00:00, 20.50s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "2024-11-14 19:20:29,243 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:20:29 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "2024-11-14 19:20:29,316 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:20:29 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.75}}\n" + ] + } + ], + "source": [ + "from easyeditor import AlphaEditHyperParams\n", + "\n", + "hparams = AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " sequential_edit=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : The current President of the United States is Joe Biden.\n", + "Generalization: The current President of the United States is Joe Biden.\n", + "Locality : The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability : The current U.S. President, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model,Evaluation_prompts, Evaluation_metrics, device=hparams.device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### LoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-14 19:42:54,874 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:42:54 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007505655288696289, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "0d7020acb56b4539a3d1c1b408bc18f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 2.6429638862609863\n", + "Total loss 2.6429638862609863\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 1.3599387407302856\n", + "Total loss 1.3599387407302856\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.5418473482131958\n", + "Total loss 0.5418473482131958\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.5228520035743713\n", + "Total loss 0.5228520035743713\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.4603128731250763\n", + "Total loss 0.4603128731250763\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.39001449942588806\n", + "Total loss 0.39001449942588806\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.37775060534477234\n", + "Total loss 0.37775060534477234\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.3374292254447937\n", + "Total loss 0.3374292254447937\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.27289214730262756\n", + "Total loss 0.27289214730262756\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.24674639105796814\n", + "Total loss 0.24674639105796814\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 0.2413826733827591\n", + "Total loss 0.2413826733827591\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.2197069525718689\n", + "Total loss 0.2197069525718689\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.19408224523067474\n", + "Total loss 0.19408224523067474\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.17192040383815765\n", + "Total loss 0.17192040383815765\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 0.15492790937423706\n", + "Total loss 0.15492790937423706\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 0.14264951646327972\n", + "Total loss 0.14264951646327972\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 0.12936238944530487\n", + "Total loss 0.12936238944530487\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 0.11950325220823288\n", + "Total loss 0.11950325220823288\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 0.11279310286045074\n", + "Total loss 0.11279310286045074\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 0.10427707433700562\n", + "Total loss 0.10427707433700562\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 0.0980779305100441\n", + "Total loss 0.0980779305100441\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 0.09480800479650497\n", + "Total loss 0.09480800479650497\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 0.08838903903961182\n", + "Total loss 0.08838903903961182\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 0.0809950903058052\n", + "Total loss 0.0809950903058052\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 0.07678549736738205\n", + "Total loss 0.07678549736738205\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 0.0739927589893341\n", + "Total loss 0.0739927589893341\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 0.06891392916440964\n", + "Total loss 0.06891392916440964\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 0.06549651175737381\n", + "Total loss 0.06549651175737381\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 0.06370970606803894\n", + "Total loss 0.06370970606803894\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 0.06049251928925514\n", + "Total loss 0.06049251928925514\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 0.059015192091464996\n", + "Total loss 0.059015192091464996\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 0.057458244264125824\n", + "Total loss 0.057458244264125824\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 0.05739090219140053\n", + "Total loss 0.05739090219140053\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 0.053173527121543884\n", + "Total loss 0.053173527121543884\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 0.053831085562705994\n", + "Total loss 0.053831085562705994\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 0.05263187363743782\n", + "Total loss 0.05263187363743782\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 0.051064085215330124\n", + "Total loss 0.051064085215330124\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 0.05075136199593544\n", + "Total loss 0.05075136199593544\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 0.051547590643167496\n", + "Total loss 0.051547590643167496\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 0.04825957119464874\n", + "Total loss 0.04825957119464874\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 0.04765207692980766\n", + "Total loss 0.04765207692980766\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 0.046823542565107346\n", + "Total loss 0.046823542565107346\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 0.04552333801984787\n", + "Total loss 0.04552333801984787\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 0.04555274918675423\n", + "Total loss 0.04555274918675423\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 0.04325835406780243\n", + "Total loss 0.04325835406780243\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 0.04367177188396454\n", + "Total loss 0.04367177188396454\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 0.04280472174286842\n", + "Total loss 0.04280472174286842\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 0.041180215775966644\n", + "Total loss 0.041180215775966644\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 0.04074548929929733\n", + "Total loss 0.04074548929929733\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 0.04134400933980942\n", + "Total loss 0.04134400933980942\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 0.04053061828017235\n", + "Total loss 0.04053061828017235\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 0.04065093398094177\n", + "Total loss 0.04065093398094177\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 0.04052331671118736\n", + "Total loss 0.04052331671118736\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 0.03988373279571533\n", + "Total loss 0.03988373279571533\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 0.04105662927031517\n", + "Total loss 0.04105662927031517\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 0.03932194784283638\n", + "Total loss 0.03932194784283638\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 0.03963640704751015\n", + "Total loss 0.03963640704751015\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 0.03938356414437294\n", + "Total loss 0.03938356414437294\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 0.0391293428838253\n", + "Total loss 0.0391293428838253\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 0.03897949680685997\n", + "Total loss 0.03897949680685997\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 0.03883472457528114\n", + "Total loss 0.03883472457528114\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 0.03857944905757904\n", + "Total loss 0.03857944905757904\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 0.03830447793006897\n", + "Total loss 0.03830447793006897\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 0.039081912487745285\n", + "Total loss 0.039081912487745285\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 0.038402825593948364\n", + "Total loss 0.038402825593948364\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 0.039030831307172775\n", + "Total loss 0.039030831307172775\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 0.03924895450472832\n", + "Total loss 0.03924895450472832\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 0.03945409879088402\n", + "Total loss 0.03945409879088402\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 0.037897076457738876\n", + "Total loss 0.037897076457738876\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 0.03965139761567116\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:16<00:16, 16.76s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total loss 0.03965139761567116\n", + "Executing LoRA algo for: [Who is the current President of the United States?] -> [Joe Biden]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 13.705687522888184\n", + "Total loss 13.705687522888184\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 0.8341963291168213\n", + "Total loss 0.8341963291168213\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.09587246924638748\n", + "Total loss 0.09587246924638748\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.022295663133263588\n", + "Total loss 0.022295663133263588\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.0027309246361255646\n", + "Total loss 0.0027309246361255646\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.0013770213117823005\n", + "Total loss 0.0013770213117823005\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.0009390695486217737\n", + "Total loss 0.0009390695486217737\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.0031901171896606684\n", + "Total loss 0.0031901171896606684\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.00013749384379480034\n", + "Total loss 0.00013749384379480034\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.00010918414773186669\n", + "Total loss 0.00010918414773186669\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 9.506048081675544e-05\n", + "Total loss 9.506048081675544e-05\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.00010793243563966826\n", + "Total loss 0.00010793243563966826\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.00012157877790741622\n", + "Total loss 0.00012157877790741622\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.00011043527774745598\n", + "Total loss 0.00011043527774745598\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 9.83976642601192e-05\n", + "Total loss 9.83976642601192e-05\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 8.922024426283315e-05\n", + "Total loss 8.922024426283315e-05\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 6.609722186112776e-05\n", + "Total loss 6.609722186112776e-05\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 6.740835669916123e-05\n", + "Total loss 6.740835669916123e-05\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 5.745560338255018e-05\n", + "Total loss 5.745560338255018e-05\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 5.161498120287433e-05\n", + "Total loss 5.161498120287433e-05\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 4.458231705939397e-05\n", + "Total loss 4.458231705939397e-05\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 3.093387567787431e-05\n", + "Total loss 3.093387567787431e-05\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 2.4854532966855913e-05\n", + "Total loss 2.4854532966855913e-05\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 2.264926843054127e-05\n", + "Total loss 2.264926843054127e-05\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 2.4020109776756726e-05\n", + "Total loss 2.4020109776756726e-05\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 2.0026776837767102e-05\n", + "Total loss 2.0026776837767102e-05\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 1.7225458577740937e-05\n", + "Total loss 1.7225458577740937e-05\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 1.5020155842648819e-05\n", + "Total loss 1.5020155842648819e-05\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 1.3112857232044917e-05\n", + "Total loss 1.3112857232044917e-05\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 1.335127126367297e-05\n", + "Total loss 1.335127126367297e-05\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 1.251682624570094e-05\n", + "Total loss 1.251682624570094e-05\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 1.0788333383970894e-05\n", + "Total loss 1.0788333383970894e-05\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 9.417452929483261e-06\n", + "Total loss 9.417452929483261e-06\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 9.834675438469276e-06\n", + "Total loss 9.834675438469276e-06\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 8.881019311957061e-06\n", + "Total loss 8.881019311957061e-06\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 8.523399628757033e-06\n", + "Total loss 8.523399628757033e-06\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 8.404190339206252e-06\n", + "Total loss 8.404190339206252e-06\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 7.5697371357819065e-06\n", + "Total loss 7.5697371357819065e-06\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 7.3313244683959056e-06\n", + "Total loss 7.3313244683959056e-06\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 7.092905889294343e-06\n", + "Total loss 7.092905889294343e-06\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 6.73528347761021e-06\n", + "Total loss 6.73528347761021e-06\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 6.735284841852263e-06\n", + "Total loss 6.735284841852263e-06\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 6.318057330645388e-06\n", + "Total loss 6.318057330645388e-06\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 6.02003683525254e-06\n", + "Total loss 6.02003683525254e-06\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 6.079639661038527e-06\n", + "Total loss 6.079639661038527e-06\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 5.245183729130076e-06\n", + "Total loss 5.245183729130076e-06\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 5.304789283400169e-06\n", + "Total loss 5.304789283400169e-06\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 5.066371613793308e-06\n", + "Total loss 5.066371613793308e-06\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 4.827955763175851e-06\n", + "Total loss 4.827955763175851e-06\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 5.006767878512619e-06\n", + "Total loss 5.006767878512619e-06\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 4.887560862698592e-06\n", + "Total loss 4.887560862698592e-06\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 4.172311037109466e-06\n", + "Total loss 4.172311037109466e-06\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 4.1723101276147645e-06\n", + "Total loss 4.1723101276147645e-06\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 4.112705482839374e-06\n", + "Total loss 4.112705482839374e-06\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 4.053102657053387e-06\n", + "Total loss 4.053102657053387e-06\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 3.9934966480359435e-06\n", + "Total loss 3.9934966480359435e-06\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 3.6358721899887314e-06\n", + "Total loss 3.6358721899887314e-06\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 3.755080570044811e-06\n", + "Total loss 3.755080570044811e-06\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 3.933895186492009e-06\n", + "Total loss 3.933895186492009e-06\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 3.576268227334367e-06\n", + "Total loss 3.576268227334367e-06\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 3.397454747755546e-06\n", + "Total loss 3.397454747755546e-06\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 3.397454747755546e-06\n", + "Total loss 3.397454747755546e-06\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 3.0994333428679965e-06\n", + "Total loss 3.0994333428679965e-06\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 3.3378505577275064e-06\n", + "Total loss 3.3378505577275064e-06\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 3.2782468224468175e-06\n", + "Total loss 3.2782468224468175e-06\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 3.0398289254662814e-06\n", + "Total loss 3.0398289254662814e-06\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 2.9802247354382416e-06\n", + "Total loss 2.9802247354382416e-06\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 3.0398289254662814e-06\n", + "Total loss 3.0398289254662814e-06\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 2.8014117106067715e-06\n", + "Total loss 2.8014117106067715e-06\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 3.039828698092606e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:30<00:00, 15.45s/it]\n", + "2024-11-14 19:43:36,695 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:43:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "2024-11-14 19:43:36,779 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:43:36 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total loss 3.039828698092606e-06\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.5}}\n" + ] + } + ], + "source": [ + "\n", + "from easyeditor import LoRAHyperParams\n", + "\n", + "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " sequential_edit=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability : Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Generalization: Joe Biden Biden Biden Biden Biden Biden Biden\n", + "Locality : Joe Biden's Joe Biden's Joe Biden\n", + "Portability : Joe Biden's Joe Biden's Joe Biden\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts, Evaluation_metrics,device=hparams.device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007822990417480469, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "824d4232f01d46db93353e6163bac194", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00