From e3fd9c1dfc70c8571e04cd6dab62f89407782cf9 Mon Sep 17 00:00:00 2001 From: KeweiXu Date: Sun, 10 Nov 2024 14:26:33 +0800 Subject: [PATCH 1/5] EasyEdit Example For US President --- hparams/AlphaEdit/llama3.1-8b.yaml | 30 + hparams/DPO/llama-7b.yaml | 20 + hparams/DeCo/{llama.yaml => llama-7b.yaml} | 3 +- .../GRACE/{llama-7B.yaml => llama-7b.yaml} | 0 hparams/LoRA/llama3.1-8b.yaml | 18 + hparams/QLoRA/llama-7b.yaml | 25 + .../{llama-3-8b.yaml => llama3.1-8b.yaml} | 8 +- .../EasyEdit_Example_US_President.ipynb | 2561 +++++++++++++++++ 8 files changed, 2660 insertions(+), 5 deletions(-) create mode 100644 hparams/AlphaEdit/llama3.1-8b.yaml create mode 100644 hparams/DPO/llama-7b.yaml rename hparams/DeCo/{llama.yaml => llama-7b.yaml} (61%) rename hparams/GRACE/{llama-7B.yaml => llama-7b.yaml} (100%) create mode 100644 hparams/LoRA/llama3.1-8b.yaml create mode 100644 hparams/QLoRA/llama-7b.yaml rename hparams/WISE/{llama-3-8b.yaml => llama3.1-8b.yaml} (85%) create mode 100644 tutorial-notebooks/EasyEdit_Example_US_President.ipynb diff --git a/hparams/AlphaEdit/llama3.1-8b.yaml b/hparams/AlphaEdit/llama3.1-8b.yaml new file mode 100644 index 00000000..9f1b2018 --- /dev/null +++ b/hparams/AlphaEdit/llama3.1-8b.yaml @@ -0,0 +1,30 @@ +alg_name: "AlphaEdit" +model_name: "./hugging_cache/llama-3.1-8b-instruct" +stats_dir: "./data/stats" +# Make sure that the projection matrix P has been downloaded from the baidu netdisk (For details, please refer to the EasyEdit/easyeditor/models/alphaedit/README.md) beforehand to avoid double computation. +# But if the projection matrix P which we provided is not needed, then nothing needs to be done to the P_loc field; +# just run the program, and the program will compute P and save it locally automatically. +P_loc: "./null_space_project.pt" +device: 0 +layers: [4, 5, 6, 7, 8] +clamp_norm_factor: 0.75 +layer_selection: "all" +fact_token: "subject_last" +v_num_grad_steps: 25 +v_lr: 1e-1 +v_loss_layer: 31 +v_weight_decay: 0.5 +kl_factor: 0.0625 +mom2_adjustment: true +mom2_update_weight: 15000 +rewrite_module_tmp: "model.layers.{}.mlp.down_proj" +layer_module_tmp: "model.layers.{}" +mlp_module_tmp: "model.layers.{}.mlp" +attn_module_tmp: "model.layers.{}.self_attn" +ln_f_module: "model.norm" +lm_head_module: "lm_head" +mom2_dataset: "wikipedia" +mom2_n_samples: 100000 +mom2_dtype: "float32" +nullspace_threshold: 2e-2 +L2: 10 diff --git a/hparams/DPO/llama-7b.yaml b/hparams/DPO/llama-7b.yaml new file mode 100644 index 00000000..b9cf9afd --- /dev/null +++ b/hparams/DPO/llama-7b.yaml @@ -0,0 +1,20 @@ +alg_name: "DPO" +model_name: "./hugging_cache/llama-2-7b" +device: 0 + +lora_type: "adalora" +layers: [] +num_steps: 7 +batch_size: 1 +max_length: 30 +lr: 5e-5 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: False +alpha: 0.99 +beta: 0.1 \ No newline at end of file diff --git a/hparams/DeCo/llama.yaml b/hparams/DeCo/llama-7b.yaml similarity index 61% rename from hparams/DeCo/llama.yaml rename to hparams/DeCo/llama-7b.yaml index fbe59c92..ddbc2e97 100644 --- a/hparams/DeCo/llama.yaml +++ b/hparams/DeCo/llama-7b.yaml @@ -1,5 +1,6 @@ alg_name: "deco" -model_name: "./huggyllama-7b" or "./llava-7b-hf" +# model_name: "./huggyllama-7b" or "./llava-7b-hf" +model_name: "./hugging_cache/llama-2-7b" device: 1 alpha: 0.6 threshold_top_p: 0.9 diff --git a/hparams/GRACE/llama-7B.yaml b/hparams/GRACE/llama-7b.yaml similarity index 100% rename from hparams/GRACE/llama-7B.yaml rename to hparams/GRACE/llama-7b.yaml diff --git a/hparams/LoRA/llama3.1-8b.yaml b/hparams/LoRA/llama3.1-8b.yaml new file mode 100644 index 00000000..7fc22cfc --- /dev/null +++ b/hparams/LoRA/llama3.1-8b.yaml @@ -0,0 +1,18 @@ +alg_name: "LoRA" +model_name: "./hugging_cache/llama-3.2-3b-instruct" +device: 1 + +lora_type: "adalora" +layers: [] +num_steps: 50 +batch_size: 1 +max_length: 30 +lr: 5e-3 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: false \ No newline at end of file diff --git a/hparams/QLoRA/llama-7b.yaml b/hparams/QLoRA/llama-7b.yaml new file mode 100644 index 00000000..3be8d28c --- /dev/null +++ b/hparams/QLoRA/llama-7b.yaml @@ -0,0 +1,25 @@ +alg_name: "QLoRA" +model_name: "./hugging_cache/llama-2-7b" +device: 1 + +# QLoRA specific settings +quantization_bit: 4 +double_quant: true +quant_type: "nf4" # nf4, fp4, int4, int8 + +# LoRA settings +lora_type: "lora" # QLoRA typically uses standard LoRA, not AdaLoRA +lora_r: 8 +lora_alpha: 32 +lora_dropout: 0.1 +target_modules: ["q_proj", "v_proj"] + +# Training settings +num_steps: 1 +batch_size: 1 +max_length: 30 +lr: 5e-3 +weight_decay: 0.0 + +# Additional settings +model_parallel: false \ No newline at end of file diff --git a/hparams/WISE/llama-3-8b.yaml b/hparams/WISE/llama3.1-8b.yaml similarity index 85% rename from hparams/WISE/llama-3-8b.yaml rename to hparams/WISE/llama3.1-8b.yaml index e44414ce..2431272f 100644 --- a/hparams/WISE/llama-3-8b.yaml +++ b/hparams/WISE/llama3.1-8b.yaml @@ -4,17 +4,17 @@ model_name: "./hugging_cache/llama-3.1-8b-instruct" device: 0 mask_ratio: 0.2 -edit_lr: 1.0 -n_iter: 70 +edit_lr: 0.9 +n_iter: 30 norm_constraint: 1.0 -act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma +act_margin: [5.0, 10.0, 10.0] # alpha, beta, gamma act_ratio: 0.88 save_freq: 500 merge_freq: 1000 merge_alg: 'ties' objective_optimization: 'only_label' inner_params: -- model.layers[27].mlp.down_proj.weight +- model.layers[29].mlp.down_proj.weight ## alternative: WISE-Merge, WISE-Retrieve diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb new file mode 100644 index 00000000..93c9253f --- /dev/null +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -0,0 +1,2561 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# EasyEdit Example with the **US President**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">Tutorial author: Kewei Xu()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Recently the U.S. election has concluded, and `Donald Trump` has been elected President.
\n", + "We tested knowledge editing in this scenario:\n", + "- `Biden → Trump`
\n", + "- `Biden → Trump → Biden` (simulating the interesting shift of Trump → Biden → Trump).
\n", + "\n", + "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to implement editing.
\n", + "Specifically, the `Wise`、`LoRA`、`Prompt` methods are used to edit the `Llama3.1-8b` model.
\n", + "As for `AlphaEdit`, due to limitations in computational power and time, we currently only provide the projection matrix P for `Llama3-8B`, specifically for layers [4, 5, 6, 7, 8]." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Method" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### WISE\n", + "Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models?](http://arxiv.org/pdf/2405.14768)\n", + " \n", + "**WISE**, is an approach for lifelong model editing of Large Language Models (LLMs). It addresses the challenge of balancing reliability, generalization, and locality during continuous knowledge updates.\n", + "It provides an effective solution for continuous learning and knowledge updating in large language models through its innovative memory management and editing strategies.\n", + "\n", + "### AlphaEdit\n", + "Paper: [AlphaEdit: Null-Space Constrained Knowledge Editing for Language Models](https://arxiv.org/pdf/2410.02355)\n", + "\n", + "**AlphaEdit** minimizes disruption to the preserved knowledge by projecting parameter perturbations onto the null space of its key matrices. It then removes the output error related to it from the current objective, allowing the model to focus solely on knowledge update without trade-off. By leveraging the mathematical properties of matrix projection and null space, AlphaEdit ensures that the distribution of hidden representations within LLMs remains invariant after edits. This invariance allows post-edited LLMs to effectively handle both knowledge update and preservation simultaneously.\n", + "\n", + "### AdaLoRA\n", + "Paper: [AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/pdf/2303.10512)\n", + "\n", + "**AdaLoRA** introduces a method that efficiently fine-tunes large pre-trained language models by adaptively allocating update budgets based on parameter importance. Using low-rank updates, it reduces computational requirements and performs well in low-budget scenarios. The code is available on GitHub\n", + "\n", + "### Prompt\n", + "Guide the model to answer through prompts\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We tested the following indicators:\n", + "- `Reliability`: the success rate of editing with a given editing descriptor
\n", + "**Question**: *Who is the current President of the United States?*\n", + "\n", + "- `Generalization`: the success rate of editing within the editing scope
\n", + "**Question**: *Who is the head of state in the United States right now?*\n", + "\n", + "- `Locality`: whether the model's output changes after editing for unrelated inputs
\n", + "**Question**: *Where is the capital of the United States?*\n", + "\n", + "- `Portability`: the success rate of editing for reasoning/application(one hop, synonym, logical generalization)
\n", + "**Question**: *Where is the current U.S. President born?*\n", + "\n", + "\n", + "The editing results are shown in the table below, with **highlighted** areas indicating that the output **does not match the answer**.
\n", + "From the table, it can be seen that:
\n", + "**_WISE and Prompt_** can complete the task well.
\n", + "**_LoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.
\n", + "**_AlphaEdit_** has problems in both cases for Locality and Portability, we speculate that this may be due to model differences and hyperparameter tuning issues.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ReliabilityGeneralizationLocalityPortability
QuestionsWho is the current President of the United States?Who is the head of state in the United States right now?Where is the capital of the United States?Where is the current U.S. President born?
First Edit: Joe Biden ——> Donald Trump
AnswerDonald TrumpDonald TrumpWashington, D.C.Queens, New York
WISEDonald TrumpDonald TrumpWashington, D.CQueens, New York
AlphaEditDonald TrumpDonald TrumpWashington, D.CDonald Trump
LoRADonald TrumpDonald TrumpWashington, D.CQueens, New York
PromptDonald TrumpDonald TrumpWashington, D.C.Queens, New York
Second Edit: Joe Biden ——> Donald Trump ——> Joe Biden
AnswerJoe BidenJoe BidenWashington, D.C.Scranton, Pennsylvania
WISEJoe BidenJoe BidenWashington, D.CScranton, Pennsylvania
AlphaEditJoe BidenJoe BidenJoe BidenJoe Biden
LoRAJoe BidenJoe BidenJoe BidenJoe Biden
PromptJoe BidenJoe BidenWashington, D.CScranton, Pennsylvania
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/mnt/8t/xkw/EasyEdit\n" + ] + } + ], + "source": [ + "%cd .." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load the evaluation function " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "from easyeditor import BaseEditor\n", + "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", + " 'Who is the head of state in the United States right now?',\n", + " \"Where is the capital of the United States?\" ,\n", + " 'Where is the current U.S. President born ?']\n", + "Evaluation_metrics = [\"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n", + "\n", + "def evaluate(model, Evaluation_prompts,Evaluation_metrics, device=1):\n", + " device = f\"cuda:{device}\"\n", + " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b')\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id\n", + " tokenizer.padding_side='left'\n", + " \n", + "\n", + " for i in range(len(Evaluation_prompts)):\n", + " \n", + " inputs = [f\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n \\\n", + " You are a helpful assistant. Please answer the questions to the best of your ability.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n \\\n", + " {Evaluation_prompts[i]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\"]\n", + " # inputs = [f\"You are a helpful assistant.Please answer the folling question: {Evaluation_prompts[i]}\"]\n", + " input_ids = tokenizer(inputs, return_tensors=\"pt\").to(device)\n", + " outputs = model.generate(\n", + " input_ids=input_ids['input_ids'],\n", + " attention_mask=input_ids['attention_mask'],\n", + " max_new_tokens=20,\n", + " pad_token_id= tokenizer.eos_token_id,\n", + " do_sample=False,\n", + " use_cache=False\n", + " )\n", + " response = [tokenizer.decode(x[input_ids['input_ids'].shape[-1]: ]) for x in outputs.detach().cpu().numpy().tolist()]\n", + " # response = outputs[0].detach().cpu().numpy().tolist()[input_ids['input_ids'][-1]:]\n", + " # response = tokenizer.decode(response, skip_special_tokens=True)\n", + "\n", + " print(f\"{Evaluation_metrics[i]}: {response}\")\n", + "\n", + "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1):\n", + " device = f\"cuda:{device}\"\n", + " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3.1-8b-instruct')\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id\n", + " tokenizer.padding_side='left'\n", + "\n", + " for i in range(len(Evaluation_prompts)):\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": Evaluation_prompts[i] },\n", + " ]\n", + " input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\"\n", + " ).to(model.device)\n", + "\n", + " terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n", + " outputs = model.generate(\n", + " input_ids = input_ids,\n", + " max_new_tokens=256,\n", + " eos_token_id=terminators,\n", + " pad_token_id= tokenizer.eos_token_id\n", + " )\n", + " response = outputs[0][input_ids.shape[-1]:]\n", + " response = tokenizer.decode(response, skip_special_tokens=True)\n", + "\n", + " print(f\"{Evaluation_metrics[i]}: {response}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Orignal Output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test the output of the initial model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007864236831665039, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "5ec93b35b8f1410c8365ead122009f97", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 Donald Trump`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Edit data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "## edit once\n", + "## Joe Biden ——> Donald Trump\n", + "prompts = [\"Who is the current President of the United States?\" ]\n", + "subject = ['President']\n", + "ground_truth = ['Joe Biden']\n", + "target_new = ['Donald Trump']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### WISE" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-10 12:39:07,415 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/10/2024 12:39:07 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.004432201385498047, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "2208433569a64f2fac21a90299bd4abe", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "loss 25.854 = 5.854 + 20.0\n", + "loss 9.917 = 4.367 + 5.551\n", + "loss 4.498 = 0.03 + 4.468\n", + "loss 2.664 = 0.0 + 2.664\n", + "loss 0.916 = 0.0 + 0.915\n", + "loss 0.0 = 0.0 + 0.0\n", + "loss 0.0 = 0.0 + 0.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:10<00:00, 10.81s/it]\n", + "2024-11-10 12:39:27,315 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:39:27 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n", + "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "from easyeditor import WISEHyperParams\n", + "import json\n", + "\n", + "loc_data = json.load(open('./data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8'))[:1]\n", + "loc_prompts = [edit_data_['loc'] + ' ' + edit_data_['loc_ans'] for edit_data_ in loc_data]\n", + "\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3.1-8b.yaml')\n", + "hparams.device = 1\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " \n", + " loc_prompts=loc_prompts,\n", + " sequential_edit=True,\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: The current President Trump left office in January Trump was replaced Donald Trump was replaced Donald Trump was replaced Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", + "Generalization: As Trump is no longer in office Donald Trump's successor Donald Trump was Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", + "Locality: The capital of the United Donald Trump's presidential term, the capital of the United States is Washington, D Donald Trump was not the first president, Donald Trump's presidential term was the 45th, Donald Trump was the 45 Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", + "Portability: The current U Trump was born in Queens, New York. Donald Trump was the 45 Trump Trump Trump Trump Trump Donald Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Donald Trump Trump Donald Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### AlphaEdit" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-10 12:42:35,976 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/10/2024 12:42:35 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007699728012084961, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea5dbd56592b47528a59430738e119f5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [ Donald Trump]\n", + "Cached context templates [['{}'], ['The 2018-19 NBA season is. {}', 'Therefore, we will not discuss the details of. {}', 'Because the number of people living with diabetes continues. {}', 'I have always been interested in the history of. {}', 'You may also wish to search for items by. {}']]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 4.299 = 4.299 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.09041433036327362\n", + "loss 3.399 = 3.396 + 0.001 + 0.002 avg prob of [ Donald Trump] 0.21825557947158813\n", + "loss 2.767 = 2.761 + 0.003 + 0.002 avg prob of [ Donald Trump] 0.4464951753616333\n", + "loss 2.353 = 2.35 + 0.0 + 0.003 avg prob of [ Donald Trump] 0.7211445569992065\n", + "loss 2.262 = 2.258 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8048359751701355\n", + "loss 2.242 = 2.237 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8246363401412964\n", + "loss 2.242 = 2.237 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8255119323730469\n", + "loss 2.237 = 2.231 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8306096792221069\n", + "loss 2.235 = 2.23 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8323876261711121\n", + "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8328641653060913\n", + "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8330349326133728\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.833112359046936\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331531286239624\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331767916679382\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331915736198425\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332012891769409\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332079648971558\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332127332687378\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332163691520691\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332261443138123\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332351446151733\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332424163818359\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332486152648926\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332538604736328\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332585096359253\n", + "Init norm 46.17322540283203 | Delta norm 34.629920959472656 | Target norm 58.87141799926758\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(58.9720, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.5992, device='cuda:1')\n", + "upd norm tensor(1.2758, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(57.4963, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9419, device='cuda:1')\n", + "upd norm tensor(1.8219, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(53.5156, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9026, device='cuda:1')\n", + "upd norm tensor(2.2746, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(48.2179, device='cuda:1', grad_fn=)\n", + "orig norm tensor(79.0248, device='cuda:1')\n", + "upd norm tensor(2.9401, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(42.1711, device='cuda:1', grad_fn=)\n", + "orig norm tensor(78.7670, device='cuda:1')\n", + "upd norm tensor(5.2222, device='cuda:1', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:22<00:00, 22.73s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "2024-11-10 12:43:08,192 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:43:08 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 1.0}, 'post': {'rewrite_acc': 0.0}}\n", + "[{'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "from easyeditor import AlphaEditHyperParams\n", + "hparams=AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", + "hparams.device = 1\n", + "editor=BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " \n", + " sequential_edit=True\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n", + "Generalization: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n", + "Locality: ['\\n Where is the capital ofthe United States?Washington, D.C.\\n Where is the']\n", + "Portability: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n" + ] + } + ], + "source": [ + "evaluate(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### LoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-07 16:07:11,058 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/07/2024 16:07:11 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007943391799926758, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 2, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "aec251c5c46147e2b5464f4902086d4c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 [Donald Trump]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 4.321556568145752\n", + "Total loss 4.321556568145752\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 2.935058116912842\n", + "Total loss 2.935058116912842\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.5914953947067261\n", + "Total loss 0.5914953947067261\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.4943681061267853\n", + "Total loss 0.4943681061267853\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.5276131629943848\n", + "Total loss 0.5276131629943848\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.5325935482978821\n", + "Total loss 0.5325935482978821\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.49563464522361755\n", + "Total loss 0.49563464522361755\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.4502723813056946\n", + "Total loss 0.4502723813056946\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.40860509872436523\n", + "Total loss 0.40860509872436523\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.36605969071388245\n", + "Total loss 0.36605969071388245\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 0.32609912753105164\n", + "Total loss 0.32609912753105164\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.289233535528183\n", + "Total loss 0.289233535528183\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.2593093514442444\n", + "Total loss 0.2593093514442444\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.23281975090503693\n", + "Total loss 0.23281975090503693\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 0.20349939167499542\n", + "Total loss 0.20349939167499542\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 0.17900611460208893\n", + "Total loss 0.17900611460208893\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 0.15917769074440002\n", + "Total loss 0.15917769074440002\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 0.14373129606246948\n", + "Total loss 0.14373129606246948\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 0.13142257928848267\n", + "Total loss 0.13142257928848267\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 0.11883941292762756\n", + "Total loss 0.11883941292762756\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 0.11019845306873322\n", + "Total loss 0.11019845306873322\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 0.10458292067050934\n", + "Total loss 0.10458292067050934\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 0.09492404013872147\n", + "Total loss 0.09492404013872147\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 0.0870729312300682\n", + "Total loss 0.0870729312300682\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 0.08295147866010666\n", + "Total loss 0.08295147866010666\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 0.07843763381242752\n", + "Total loss 0.07843763381242752\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 0.07287432253360748\n", + "Total loss 0.07287432253360748\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 0.06879541277885437\n", + "Total loss 0.06879541277885437\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 0.06624111533164978\n", + "Total loss 0.06624111533164978\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 0.062170736491680145\n", + "Total loss 0.062170736491680145\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 0.05983356758952141\n", + "Total loss 0.05983356758952141\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 0.05892789736390114\n", + "Total loss 0.05892789736390114\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 0.05899837985634804\n", + "Total loss 0.05899837985634804\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 0.05542633682489395\n", + "Total loss 0.05542633682489395\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 0.052374809980392456\n", + "Total loss 0.052374809980392456\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 0.05026000738143921\n", + "Total loss 0.05026000738143921\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 0.04752592742443085\n", + "Total loss 0.04752592742443085\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 0.04687608778476715\n", + "Total loss 0.04687608778476715\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 0.04480736330151558\n", + "Total loss 0.04480736330151558\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 0.04410180449485779\n", + "Total loss 0.04410180449485779\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 0.040795814245939255\n", + "Total loss 0.040795814245939255\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 0.042223721742630005\n", + "Total loss 0.042223721742630005\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 0.03964998573064804\n", + "Total loss 0.03964998573064804\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 0.040437471121549606\n", + "Total loss 0.040437471121549606\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 0.03891940787434578\n", + "Total loss 0.03891940787434578\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 0.038434166461229324\n", + "Total loss 0.038434166461229324\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 0.03714253753423691\n", + "Total loss 0.03714253753423691\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 0.03693273290991783\n", + "Total loss 0.03693273290991783\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 0.03618853911757469\n", + "Total loss 0.03618853911757469\n", + "====================\n", + "Epoch: 49\n", + "====================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:09<00:00, 9.89s/it]\n", + "2024-11-07 16:07:26,595 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/07/2024 16:07:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch loss 0.03574013337492943\n", + "Total loss 0.03574013337492943\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n", + "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "from easyeditor import LoRAHyperParams\n", + "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3.1-8b.yaml')\n", + "hparams.device = 1\n", + "editor=BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " \n", + " sequential_edit=True\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: Donald Trump was the President of the United States from 2017 to 2021. Donald Trump was succeeded by Donald Trump's successor Donald Trump's successor Donald Trump's successor Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald\n", + "Generalization: Donald Trump was the President of the United States from 2017 to 2021, Donald Trump was succeeded by Donald Trump's Vice President, Donald Trump's Vice President Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", + "Locality: The capital of the United States is Washington, D.C. (District of Columbia).\n", + "Portability: Donald Trump, the 45th U.S. President, was born in Queens, New York City, on June 14, 1946.\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.005389690399169922, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "125da1af9e3446eb83fd1284daa38604", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 Donald Trump —> Joe Biden" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Edit data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "## Joe Biden —> Donald Trump —> Joe Biden\n", + "prompts = [\"Who is the current President of the United States?\",\n", + " \"Who is the current President of the United States?\" ]\n", + "subject = ['President', 'President']\n", + "ground_truth = ['Joe Biden', 'Donald Trump']\n", + "target_new = ['Donald Trump', 'Joe Biden']\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### WISE" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-10 12:48:58,607 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/10/2024 12:48:58 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.008575677871704102, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "db3e4cefd2ac4d8f80007dbff3915a7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [Donald Trump]\n", + "loss 25.854 = 5.854 + 20.0\n", + "loss 9.917 = 4.367 + 5.551\n", + "loss 4.498 = 0.03 + 4.468\n", + "loss 2.664 = 0.0 + 2.664\n", + "loss 0.916 = 0.0 + 0.915\n", + "loss 0.0 = 0.0 + 0.0\n", + "loss 0.0 = 0.0 + 0.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:10<00:10, 10.77s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Executing WISE algorithm for the update: \n", + "[Who is the current President of the United States?] -> [Joe Biden]\n", + "loss 15.073 = 15.073 + 0.0\n", + "loss 0.756 = 0.009 + 0.747\n", + "loss 0.003 = 0.003 + 0.0\n", + "loss 0.003 = 0.003 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n", + "loss 0.002 = 0.002 + 0.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:21<00:00, 10.96s/it]\n", + "2024-11-10 12:49:29,965 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:49:29 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "2024-11-10 12:49:30,075 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:49:30 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 0.75}}\n", + "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "from easyeditor import WISEHyperParams\n", + "import json\n", + "\n", + "loc_data = json.load(open('./data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8'))[:2]\n", + "loc_prompts = [edit_data_['loc'] + ' ' + edit_data_['loc_ans'] for edit_data_ in loc_data]\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3.1-8b.yaml')\n", + "hparams.device=1\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + " \n", + " loc_prompts=loc_prompts,\n", + " sequential_edit=True,\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: The current President Joe Biden has been in office since Joe Biden was inaugurated on Joe Biden's 78th birthday January 20, Joe Biden Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n", + "Generalization: As Joe Biden is the current President of Joe Biden is the President of the United Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n", + "Locality: The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability: The current U Joe Biden U.S. President Joe Biden was born in Scranton, Pennsylvania Joe Biden was born in Scr Joe Biden Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### AlphaEdit" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-10 12:52:34,953 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/10/2024 12:52:34 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007628440856933594, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "782b6df90d9f45c6ad938911e8dd4d04", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 [ Donald Trump]\n", + "Cached context templates [['{}'], ['The 2018-19 NBA season is. {}', 'Therefore, we will not discuss the details of. {}', 'Because the number of people living with diabetes continues. {}', 'I have always been interested in the history of. {}', 'You may also wish to search for items by. {}']]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 4.299 = 4.299 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.09041433036327362\n", + "loss 3.399 = 3.396 + 0.001 + 0.002 avg prob of [ Donald Trump] 0.21825557947158813\n", + "loss 2.767 = 2.761 + 0.003 + 0.002 avg prob of [ Donald Trump] 0.4464951753616333\n", + "loss 2.353 = 2.35 + 0.0 + 0.003 avg prob of [ Donald Trump] 0.7211445569992065\n", + "loss 2.262 = 2.258 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8048359751701355\n", + "loss 2.242 = 2.237 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8246363401412964\n", + "loss 2.242 = 2.237 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8255119323730469\n", + "loss 2.237 = 2.231 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8306096792221069\n", + "loss 2.235 = 2.23 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8323876261711121\n", + "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8328641653060913\n", + "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8330349326133728\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.833112359046936\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331531286239624\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331767916679382\n", + "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331915736198425\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332012891769409\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332079648971558\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332127332687378\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332163691520691\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332261443138123\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332351446151733\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332424163818359\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332486152648926\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332538604736328\n", + "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332585096359253\n", + "Init norm 46.17322540283203 | Delta norm 34.629920959472656 | Target norm 58.87141799926758\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(58.9720, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.5992, device='cuda:1')\n", + "upd norm tensor(1.2758, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(57.4963, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9419, device='cuda:1')\n", + "upd norm tensor(1.8219, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(53.5156, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9026, device='cuda:1')\n", + "upd norm tensor(2.2746, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(48.2179, device='cuda:1', grad_fn=)\n", + "orig norm tensor(79.0248, device='cuda:1')\n", + "upd norm tensor(2.9401, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(42.1711, device='cuda:1', grad_fn=)\n", + "orig norm tensor(78.7670, device='cuda:1')\n", + "upd norm tensor(5.2222, device='cuda:1', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:22<00:22, 22.61s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "Executing AlphaEdit algo for: [Who is the current {} of the United States?] -> [ Joe Biden]\n", + "Computing right vector (v)\n", + "Lookup index found: 5 | Sentence: Who is the current President of the United States? Joe | Token: President\n", + "Rewrite layer is 8\n", + "Tying optimization objective to 31\n", + "Recording initial value of v*\n", + "loss 8.092 = 8.092 + 0.0 + 0.0 avg prob of [ Joe Biden] 0.0012095430865883827\n", + "loss 6.742 = 6.741 + 0.0 + 0.001 avg prob of [ Joe Biden] 0.00625673308968544\n", + "loss 4.74 = 4.738 + 0.0 + 0.002 avg prob of [ Joe Biden] 0.07003729790449142\n", + "loss 3.583 = 3.581 + 0.0 + 0.002 avg prob of [ Joe Biden] 0.2744549512863159\n", + "loss 3.135 = 3.132 + 0.0 + 0.003 avg prob of [ Joe Biden] 0.46447789669036865\n", + "loss 3.057 = 3.054 + 0.0 + 0.003 avg prob of [ Joe Biden] 0.5108730792999268\n", + "loss 2.996 = 2.992 + 0.001 + 0.004 avg prob of [ Joe Biden] 0.5502802133560181\n", + "loss 2.973 = 2.962 + 0.008 + 0.004 avg prob of [ Joe Biden] 0.5703158378601074\n", + "loss 2.968 = 2.964 + 0.0 + 0.004 avg prob of [ Joe Biden] 0.5694985389709473\n", + "loss 2.908 = 2.903 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.6119259595870972\n", + "loss 2.894 = 2.889 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.623217761516571\n", + "loss 2.85 = 2.845 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.6566493511199951\n", + "loss 2.835 = 2.829 + 0.001 + 0.005 avg prob of [ Joe Biden] 0.6692723631858826\n", + "loss 2.809 = 2.798 + 0.005 + 0.005 avg prob of [ Joe Biden] 0.6948184967041016\n", + "loss 2.775 = 2.768 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.7195937633514404\n", + "loss 2.761 = 2.755 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7305320501327515\n", + "loss 2.734 = 2.728 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7539100050926208\n", + "loss 2.731 = 2.725 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7571267485618591\n", + "loss 2.714 = 2.708 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.772413432598114\n", + "loss 2.712 = 2.705 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7749055624008179\n", + "loss 2.697 = 2.69 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.788467288017273\n", + "loss 2.691 = 2.684 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.7946388125419617\n", + "loss 2.682 = 2.675 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.803372859954834\n", + "loss 2.679 = 2.671 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.8066705465316772\n", + "loss 2.675 = 2.668 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.8101906776428223\n", + "Init norm 58.270999908447266 | Delta norm 43.70325469970703 | Target norm 71.81600189208984\n", + "\n", + "\n", + "LAYER 4\n", + "\n", + "Writing 1 key/value pair(s) into layer 4\n", + "z error tensor(56.2028, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.6094, device='cuda:1')\n", + "upd norm tensor(1.0710, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 5\n", + "\n", + "Writing 1 key/value pair(s) into layer 5\n", + "z error tensor(55.3878, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9589, device='cuda:1')\n", + "upd norm tensor(1.4956, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 6\n", + "\n", + "Writing 1 key/value pair(s) into layer 6\n", + "z error tensor(53.6229, device='cuda:1', grad_fn=)\n", + "orig norm tensor(77.9275, device='cuda:1')\n", + "upd norm tensor(1.8921, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 7\n", + "\n", + "Writing 1 key/value pair(s) into layer 7\n", + "z error tensor(51.1769, device='cuda:1', grad_fn=)\n", + "orig norm tensor(79.0695, device='cuda:1')\n", + "upd norm tensor(2.6346, device='cuda:1', grad_fn=)\n", + "\n", + "\n", + "LAYER 8\n", + "\n", + "Writing 1 key/value pair(s) into layer 8\n", + "z error tensor(47.3937, device='cuda:1', grad_fn=)\n", + "orig norm tensor(78.9209, device='cuda:1')\n", + "upd norm tensor(4.8986, device='cuda:1', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:41<00:00, 20.82s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n", + "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "2024-11-10 12:53:26,318 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:53:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "2024-11-10 12:53:26,391 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/10/2024 12:53:26 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.0}}\n", + "[{'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "from easyeditor import AlphaEditHyperParams\n", + "\n", + "hparams = AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", + "hparams.device = 1\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + "\n", + " sequential_edit=True,\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: [' Biden:// Biden:// Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", + "Generalization: [' Biden:// Bidenating Biden Joeating Biden:// Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", + "Locality: [' Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", + "Portability: [' Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n" + ] + } + ], + "source": [ + "evaluate(edited_model,Evaluation_prompts, Evaluation_metrics, device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### LoRA" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-07 16:11:12,868 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/07/2024 16:11:12 - INFO - easyeditor.editors.editor - Instantiating model\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007766246795654297, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 2, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "dbe7a3ce8d324344bad26759c2e0348b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 [Donald Trump]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 4.321556568145752\n", + "Total loss 4.321556568145752\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 2.935058116912842\n", + "Total loss 2.935058116912842\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.5914953947067261\n", + "Total loss 0.5914953947067261\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.4943681061267853\n", + "Total loss 0.4943681061267853\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.5276131629943848\n", + "Total loss 0.5276131629943848\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.5325935482978821\n", + "Total loss 0.5325935482978821\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.49563464522361755\n", + "Total loss 0.49563464522361755\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.4502723813056946\n", + "Total loss 0.4502723813056946\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.40860509872436523\n", + "Total loss 0.40860509872436523\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.36605969071388245\n", + "Total loss 0.36605969071388245\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 0.32609912753105164\n", + "Total loss 0.32609912753105164\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.289233535528183\n", + "Total loss 0.289233535528183\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.2593093514442444\n", + "Total loss 0.2593093514442444\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.23281975090503693\n", + "Total loss 0.23281975090503693\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 0.20349939167499542\n", + "Total loss 0.20349939167499542\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 0.17900611460208893\n", + "Total loss 0.17900611460208893\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 0.15917769074440002\n", + "Total loss 0.15917769074440002\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 0.14373129606246948\n", + "Total loss 0.14373129606246948\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 0.13142257928848267\n", + "Total loss 0.13142257928848267\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 0.11883941292762756\n", + "Total loss 0.11883941292762756\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 0.11019845306873322\n", + "Total loss 0.11019845306873322\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 0.10458292067050934\n", + "Total loss 0.10458292067050934\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 0.09492404013872147\n", + "Total loss 0.09492404013872147\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 0.0870729312300682\n", + "Total loss 0.0870729312300682\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 0.08295147866010666\n", + "Total loss 0.08295147866010666\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 0.07843763381242752\n", + "Total loss 0.07843763381242752\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 0.07287432253360748\n", + "Total loss 0.07287432253360748\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 0.06879541277885437\n", + "Total loss 0.06879541277885437\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 0.06624111533164978\n", + "Total loss 0.06624111533164978\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 0.062170736491680145\n", + "Total loss 0.062170736491680145\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 0.05983356758952141\n", + "Total loss 0.05983356758952141\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 0.05892789736390114\n", + "Total loss 0.05892789736390114\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 0.05899837985634804\n", + "Total loss 0.05899837985634804\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 0.05542633682489395\n", + "Total loss 0.05542633682489395\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 0.052374809980392456\n", + "Total loss 0.052374809980392456\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 0.05026000738143921\n", + "Total loss 0.05026000738143921\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 0.04752592742443085\n", + "Total loss 0.04752592742443085\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 0.04687608778476715\n", + "Total loss 0.04687608778476715\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 0.04480736330151558\n", + "Total loss 0.04480736330151558\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 0.04410180449485779\n", + "Total loss 0.04410180449485779\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 0.040795814245939255\n", + "Total loss 0.040795814245939255\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 0.042223721742630005\n", + "Total loss 0.042223721742630005\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 0.03964998573064804\n", + "Total loss 0.03964998573064804\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 0.040437471121549606\n", + "Total loss 0.040437471121549606\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 0.03891940787434578\n", + "Total loss 0.03891940787434578\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 0.038434166461229324\n", + "Total loss 0.038434166461229324\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 0.03714253753423691\n", + "Total loss 0.03714253753423691\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 0.03693273290991783\n", + "Total loss 0.03693273290991783\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 0.03618853911757469\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 1/2 [00:09<00:09, 9.98s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total loss 0.03618853911757469\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 0.03574013337492943\n", + "Total loss 0.03574013337492943\n", + "Executing LoRA algo for: [Who is the current President of the United States?] -> [Joe Biden]\n", + "====================\n", + "Epoch: 0\n", + "====================\n", + "Batch loss 11.633792877197266\n", + "Total loss 11.633792877197266\n", + "====================\n", + "Epoch: 1\n", + "====================\n", + "Batch loss 1.4652514457702637\n", + "Total loss 1.4652514457702637\n", + "====================\n", + "Epoch: 2\n", + "====================\n", + "Batch loss 0.0561312697827816\n", + "Total loss 0.0561312697827816\n", + "====================\n", + "Epoch: 3\n", + "====================\n", + "Batch loss 0.0029083709232509136\n", + "Total loss 0.0029083709232509136\n", + "====================\n", + "Epoch: 4\n", + "====================\n", + "Batch loss 0.000780549249611795\n", + "Total loss 0.000780549249611795\n", + "====================\n", + "Epoch: 5\n", + "====================\n", + "Batch loss 0.0003094780258834362\n", + "Total loss 0.0003094780258834362\n", + "====================\n", + "Epoch: 6\n", + "====================\n", + "Batch loss 0.0003064991033170372\n", + "Total loss 0.0003064991033170372\n", + "====================\n", + "Epoch: 7\n", + "====================\n", + "Batch loss 0.0005821007653139532\n", + "Total loss 0.0005821007653139532\n", + "====================\n", + "Epoch: 8\n", + "====================\n", + "Batch loss 0.0009712819592095912\n", + "Total loss 0.0009712819592095912\n", + "====================\n", + "Epoch: 9\n", + "====================\n", + "Batch loss 0.0013521392829716206\n", + "Total loss 0.0013521392829716206\n", + "====================\n", + "Epoch: 10\n", + "====================\n", + "Batch loss 0.0008831200539134443\n", + "Total loss 0.0008831200539134443\n", + "====================\n", + "Epoch: 11\n", + "====================\n", + "Batch loss 0.00034314411459490657\n", + "Total loss 0.00034314411459490657\n", + "====================\n", + "Epoch: 12\n", + "====================\n", + "Batch loss 0.0001988188741961494\n", + "Total loss 0.0001988188741961494\n", + "====================\n", + "Epoch: 13\n", + "====================\n", + "Batch loss 0.00020840521028731018\n", + "Total loss 0.00020840521028731018\n", + "====================\n", + "Epoch: 14\n", + "====================\n", + "Batch loss 0.00011371583968866616\n", + "Total loss 0.00011371583968866616\n", + "====================\n", + "Epoch: 15\n", + "====================\n", + "Batch loss 7.82558781793341e-05\n", + "Total loss 7.82558781793341e-05\n", + "====================\n", + "Epoch: 16\n", + "====================\n", + "Batch loss 6.049565854482353e-05\n", + "Total loss 6.049565854482353e-05\n", + "====================\n", + "Epoch: 17\n", + "====================\n", + "Batch loss 4.762218668474816e-05\n", + "Total loss 4.762218668474816e-05\n", + "====================\n", + "Epoch: 18\n", + "====================\n", + "Batch loss 3.903979450115003e-05\n", + "Total loss 3.903979450115003e-05\n", + "====================\n", + "Epoch: 19\n", + "====================\n", + "Batch loss 3.617890615714714e-05\n", + "Total loss 3.617890615714714e-05\n", + "====================\n", + "Epoch: 20\n", + "====================\n", + "Batch loss 2.777510781015735e-05\n", + "Total loss 2.777510781015735e-05\n", + "====================\n", + "Epoch: 21\n", + "====================\n", + "Batch loss 2.6463871108717285e-05\n", + "Total loss 2.6463871108717285e-05\n", + "====================\n", + "Epoch: 22\n", + "====================\n", + "Batch loss 2.294735168106854e-05\n", + "Total loss 2.294735168106854e-05\n", + "====================\n", + "Epoch: 23\n", + "====================\n", + "Batch loss 2.205331838922575e-05\n", + "Total loss 2.205331838922575e-05\n", + "====================\n", + "Epoch: 24\n", + "====================\n", + "Batch loss 2.098047116305679e-05\n", + "Total loss 2.098047116305679e-05\n", + "====================\n", + "Epoch: 25\n", + "====================\n", + "Batch loss 1.990763848880306e-05\n", + "Total loss 1.990763848880306e-05\n", + "====================\n", + "Epoch: 26\n", + "====================\n", + "Batch loss 1.805993633752223e-05\n", + "Total loss 1.805993633752223e-05\n", + "====================\n", + "Epoch: 27\n", + "====================\n", + "Batch loss 1.7106289305957034e-05\n", + "Total loss 1.7106289305957034e-05\n", + "====================\n", + "Epoch: 28\n", + "====================\n", + "Batch loss 1.4781775462324731e-05\n", + "Total loss 1.4781775462324731e-05\n", + "====================\n", + "Epoch: 29\n", + "====================\n", + "Batch loss 1.5914243704173714e-05\n", + "Total loss 1.5914243704173714e-05\n", + "====================\n", + "Epoch: 30\n", + "====================\n", + "Batch loss 1.5258598068612628e-05\n", + "Total loss 1.5258598068612628e-05\n", + "====================\n", + "Epoch: 31\n", + "====================\n", + "Batch loss 1.5199007975752465e-05\n", + "Total loss 1.5199007975752465e-05\n", + "====================\n", + "Epoch: 32\n", + "====================\n", + "Batch loss 1.4781775462324731e-05\n", + "Total loss 1.4781775462324731e-05\n", + "====================\n", + "Epoch: 33\n", + "====================\n", + "Batch loss 1.358971439913148e-05\n", + "Total loss 1.358971439913148e-05\n", + "====================\n", + "Epoch: 34\n", + "====================\n", + "Batch loss 1.281486856896663e-05\n", + "Total loss 1.281486856896663e-05\n", + "====================\n", + "Epoch: 35\n", + "====================\n", + "Batch loss 1.1980425369984005e-05\n", + "Total loss 1.1980425369984005e-05\n", + "====================\n", + "Epoch: 36\n", + "====================\n", + "Batch loss 1.2040026376780588e-05\n", + "Total loss 1.2040026376780588e-05\n", + "====================\n", + "Epoch: 37\n", + "====================\n", + "Batch loss 1.1324782462907024e-05\n", + "Total loss 1.1324782462907024e-05\n", + "====================\n", + "Epoch: 38\n", + "====================\n", + "Batch loss 1.0549935723247472e-05\n", + "Total loss 1.0549935723247472e-05\n", + "====================\n", + "Epoch: 39\n", + "====================\n", + "Batch loss 1.060953854903346e-05\n", + "Total loss 1.060953854903346e-05\n", + "====================\n", + "Epoch: 40\n", + "====================\n", + "Batch loss 1.0013503924710676e-05\n", + "Total loss 1.0013503924710676e-05\n", + "====================\n", + "Epoch: 41\n", + "====================\n", + "Batch loss 1.0192314221058041e-05\n", + "Total loss 1.0192314221058041e-05\n", + "====================\n", + "Epoch: 42\n", + "====================\n", + "Batch loss 9.655880603531841e-06\n", + "Total loss 9.655880603531841e-06\n", + "====================\n", + "Epoch: 43\n", + "====================\n", + "Batch loss 9.417468390893191e-06\n", + "Total loss 9.417468390893191e-06\n", + "====================\n", + "Epoch: 44\n", + "====================\n", + "Batch loss 9.059845069714356e-06\n", + "Total loss 9.059845069714356e-06\n", + "====================\n", + "Epoch: 45\n", + "====================\n", + "Batch loss 9.238655366061721e-06\n", + "Total loss 9.238655366061721e-06\n", + "====================\n", + "Epoch: 46\n", + "====================\n", + "Batch loss 7.748558346065693e-06\n", + "Total loss 7.748558346065693e-06\n", + "====================\n", + "Epoch: 47\n", + "====================\n", + "Batch loss 7.748559255560394e-06\n", + "Total loss 7.748559255560394e-06\n", + "====================\n", + "Epoch: 48\n", + "====================\n", + "Batch loss 7.510143404942937e-06\n", + "Total loss 7.510143404942937e-06\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 7.212123819044791e-06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2/2 [00:17<00:00, 8.90s/it]\n", + "2024-11-07 16:11:36,185 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/07/2024 16:11:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "2024-11-07 16:11:36,248 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/07/2024 16:11:36 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "\n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total loss 7.212123819044791e-06\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 0.5}}\n", + "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + ] + } + ], + "source": [ + "\n", + "from easyeditor import LoRAHyperParams\n", + "\n", + "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3.1-8b.yaml')\n", + "hparams.device = 1\n", + "editor = BaseEditor.from_hparams(hparams)\n", + "metrics, edited_model, _ = editor.edit(\n", + " prompts=prompts,\n", + " ground_truth=ground_truth,\n", + " target_new=target_new,\n", + " subject=subject,\n", + "\n", + " sequential_edit=True,\n", + ")\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reliability: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Generalization: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Locality: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Portability: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n" + ] + } + ], + "source": [ + "evaluate_chat_template(edited_model, Evaluation_prompts, Evaluation_metrics,device=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.009534835815429688, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Loading checkpoint shards", + "rate": null, + "total": 4, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "766aafcbd9894bf3aef6f85821e07bbd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 Date: Mon, 11 Nov 2024 16:50:05 +0800 Subject: [PATCH 2/5] EasyEdit Example with the US President --- hparams/LoRA/llama3-8b.yaml | 18 + hparams/LoRA/llama3.1-8b.yaml | 2 +- hparams/WISE/llama3-8b.yaml | 35 + .../EasyEdit_Example_US_President.ipynb | 1359 ++++++++++------- 4 files changed, 862 insertions(+), 552 deletions(-) create mode 100644 hparams/LoRA/llama3-8b.yaml create mode 100644 hparams/WISE/llama3-8b.yaml diff --git a/hparams/LoRA/llama3-8b.yaml b/hparams/LoRA/llama3-8b.yaml new file mode 100644 index 00000000..a071521f --- /dev/null +++ b/hparams/LoRA/llama3-8b.yaml @@ -0,0 +1,18 @@ +alg_name: "LoRA" +model_name: "./hugging_cache/llama-3-8b-instruct" +device: 1 + +lora_type: "adalora" +layers: [] +num_steps: 70 +batch_size: 1 +max_length: 50 +lr: 5e-3 +weight_decay: 0 +kl_factor: 0 +rank: 8 +lora_alpha: 32 +lora_dropout: 0.1 +norm_constraint: false +target_modules: ["q_proj", "v_proj"] #["up_proj", "down_proj"] #["q_proj", "v_proj"] +model_parallel: false \ No newline at end of file diff --git a/hparams/LoRA/llama3.1-8b.yaml b/hparams/LoRA/llama3.1-8b.yaml index 7fc22cfc..625769ed 100644 --- a/hparams/LoRA/llama3.1-8b.yaml +++ b/hparams/LoRA/llama3.1-8b.yaml @@ -1,5 +1,5 @@ alg_name: "LoRA" -model_name: "./hugging_cache/llama-3.2-3b-instruct" +model_name: "./hugging_cache/llama-3.1-8b-instruct" device: 1 lora_type: "adalora" diff --git a/hparams/WISE/llama3-8b.yaml b/hparams/WISE/llama3-8b.yaml new file mode 100644 index 00000000..91d25376 --- /dev/null +++ b/hparams/WISE/llama3-8b.yaml @@ -0,0 +1,35 @@ +alg_name: "WISE" +model_name: "./hugging_cache/llama-3-8b-instruct" + +device: 0 + +mask_ratio: 0.2 +edit_lr: 0.9 +n_iter: 30 +norm_constraint: 1.0 +act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma +act_ratio: 0.88 +save_freq: 500 +merge_freq: 1000 +merge_alg: 'ties' +objective_optimization: 'only_label' +inner_params: +- model.layers[29].mlp.down_proj.weight + + +## alternative: WISE-Merge, WISE-Retrieve + +# for merge (if merge) +densities: 0.53 +weights: 1.0 + +# for retrieve (if retrieve, pls set to True) +retrieve: True +replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3 + +model_parallel: False +use_chat_template: True + +# for save and load +# save_path: "./wise_checkpoint/wise.pt" +# load_path: "./wise_checkpoint/wise.pt" \ No newline at end of file diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb index 93c9253f..84c8dc49 100644 --- a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -30,9 +30,7 @@ "- `Biden → Trump`
\n", "- `Biden → Trump → Biden` (simulating the interesting shift of Trump → Biden → Trump).
\n", "\n", - "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to implement editing.
\n", - "Specifically, the `Wise`、`LoRA`、`Prompt` methods are used to edit the `Llama3.1-8b` model.
\n", - "As for `AlphaEdit`, due to limitations in computational power and time, we currently only provide the projection matrix P for `Llama3-8B`, specifically for layers [4, 5, 6, 7, 8]." + "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to edit `Llama3-8B`.
\n" ] }, { @@ -95,9 +93,10 @@ "\n", "The editing results are shown in the table below, with **highlighted** areas indicating that the output **does not match the answer**.
\n", "From the table, it can be seen that:
\n", - "**_WISE and Prompt_** can complete the task well.
\n", + "**_Prompt_** can complete the task well.
\n", + "**_WISE_** encountered Portability issues during the first editing.
\n", "**_LoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.
\n", - "**_AlphaEdit_** has problems in both cases for Locality and Portability, we speculate that this may be due to model differences and hyperparameter tuning issues.\n" + "**_AlphaEdit_** has problems in both cases for Locality and Portability.\n" ] }, { @@ -141,7 +140,7 @@ " Donald Trump\n", " Donald Trump\n", " Washington, D.C\n", - " Queens, New York \n", + " Manhattan, New York \n", " \n", " \n", " AlphaEdit\n", @@ -271,37 +270,9 @@ " 'Where is the current U.S. President born ?']\n", "Evaluation_metrics = [\"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n", "\n", - "def evaluate(model, Evaluation_prompts,Evaluation_metrics, device=1):\n", - " device = f\"cuda:{device}\"\n", - " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b')\n", - " tokenizer.pad_token_id = tokenizer.eos_token_id\n", - " tokenizer.padding_side='left'\n", - " \n", - "\n", - " for i in range(len(Evaluation_prompts)):\n", - " \n", - " inputs = [f\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n \\\n", - " You are a helpful assistant. Please answer the questions to the best of your ability.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n \\\n", - " {Evaluation_prompts[i]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\"]\n", - " # inputs = [f\"You are a helpful assistant.Please answer the folling question: {Evaluation_prompts[i]}\"]\n", - " input_ids = tokenizer(inputs, return_tensors=\"pt\").to(device)\n", - " outputs = model.generate(\n", - " input_ids=input_ids['input_ids'],\n", - " attention_mask=input_ids['attention_mask'],\n", - " max_new_tokens=20,\n", - " pad_token_id= tokenizer.eos_token_id,\n", - " do_sample=False,\n", - " use_cache=False\n", - " )\n", - " response = [tokenizer.decode(x[input_ids['input_ids'].shape[-1]: ]) for x in outputs.detach().cpu().numpy().tolist()]\n", - " # response = outputs[0].detach().cpu().numpy().tolist()[input_ids['input_ids'][-1]:]\n", - " # response = tokenizer.decode(response, skip_special_tokens=True)\n", - "\n", - " print(f\"{Evaluation_metrics[i]}: {response}\")\n", - "\n", "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1):\n", " device = f\"cuda:{device}\"\n", - " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3.1-8b-instruct')\n", + " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b-instruct')\n", " tokenizer.pad_token_id = tokenizer.eos_token_id\n", " tokenizer.padding_side='left'\n", "\n", @@ -321,7 +292,10 @@ " input_ids = input_ids,\n", " max_new_tokens=256,\n", " eos_token_id=terminators,\n", - " pad_token_id= tokenizer.eos_token_id\n", + " pad_token_id= tokenizer.eos_token_id,\n", + " do_sample=False,\n", + " # temperature=0.6,\n", + " # top_p=0.9,\n", " )\n", " response = outputs[0][input_ids.shape[-1]:]\n", " response = tokenizer.decode(response, skip_special_tokens=True)\n", @@ -345,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -354,7 +328,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007864236831665039, + "elapsed": 0.004775285720825195, "initial": 0, "n": 0, "ncols": null, @@ -368,7 +342,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "5ec93b35b8f1410c8365ead122009f97", + "model_id": "a8320b2142364749bdce8a99da896a7d", "version_major": 2, "version_minor": 0 }, @@ -383,7 +357,7 @@ "source": [ "from transformers import AutoModelForCausalLM\n", "device = 1\n", - "model = AutoModelForCausalLM.from_pretrained('./hugging_cache/llama-3.1-8b-instruct').to(device)\n" + "model = AutoModelForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to(device)\n" ] }, { @@ -391,21 +365,14 @@ "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my cut-off knowledge in December 2023, the current President of the United States is Joe Biden. However, please note that my knowledge might not be up-to-date, and there may have been a change since my knowledge cut-off date. For the most accurate and recent information, I recommend checking a reliable news source.\n", - "Generalization: As of my knowledge cutoff in 2023, the head of state in the United States is President Joe Biden.\n", + "Reliability: As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n", + "Generalization: As of my knowledge cutoff, the head of state in the United States is President Joe Biden. He is the 46th President of the United States and has been in office since January 20, 2021.\n", "Locality: The capital of the United States is Washington, D.C. (short for District of Columbia).\n", - "Portability: The current U.S. President is Joe Biden. He was born in Scranton, Pennsylvania, and later moved to Wilmington, Delaware.\n" + "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n" ] } ], @@ -417,7 +384,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Fitst Edit\n", + "### First Edit\n", "`Joe Biden —> Donald Trump`" ] }, @@ -458,8 +425,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:39:07,415 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/10/2024 12:39:07 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 16:41:23,026 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 16:41:23 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -468,7 +435,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.004432201385498047, + "elapsed": 0.0077931880950927734, "initial": 0, "n": 0, "ncols": null, @@ -482,7 +449,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "2208433569a64f2fac21a90299bd4abe", + "model_id": "20bf57dcf85d4a7d91bdbec8ad0204dd", "version_major": 2, "version_minor": 0 }, @@ -497,10 +464,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:39:10,298 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/10/2024 12:39:10 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-11 16:41:26,638 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/11/2024 16:41:26 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 [Donald Trump]\n", - "loss 25.854 = 5.854 + 20.0\n", - "loss 9.917 = 4.367 + 5.551\n", - "loss 4.498 = 0.03 + 4.468\n", - "loss 2.664 = 0.0 + 2.664\n", - "loss 0.916 = 0.0 + 0.915\n", + "loss 36.405 = 6.405 + 30.0\n", + "loss 25.036 = 6.264 + 18.772\n", + "loss 11.637 = 0.0 + 11.637\n", + "loss 6.41 = 0.0 + 6.41\n", + "loss 5.296 = 0.0 + 5.296\n", + "loss 1.282 = 0.0 + 1.282\n", "loss 0.0 = 0.0 + 0.0\n", "loss 0.0 = 0.0 + 0.0\n" ] @@ -524,32 +505,30 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:10<00:00, 10.81s/it]\n", - "2024-11-10 12:39:27,315 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 1/1 [00:09<00:00, 9.08s/it]\n", + "2024-11-11 16:41:42,132 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:39:27 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/11/2024 16:41:42 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n", - "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + "Metrics Summary: {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.0}}\n" ] } ], "source": [ "from easyeditor import WISEHyperParams\n", - "import json\n", "\n", - "loc_data = json.load(open('./data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8'))[:1]\n", - "loc_prompts = [edit_data_['loc'] + ' ' + edit_data_['loc_ans'] for edit_data_ in loc_data]\n", + "## loc_prompts: used to provide xi in Equation 5 in the paper.\n", + "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"]\n", "\n", - "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3.1-8b.yaml')\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", "hparams.device = 1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", @@ -560,30 +539,22 @@ " \n", " loc_prompts=loc_prompts,\n", " sequential_edit=True,\n", - ")\n", - "print(metrics)" + ")\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: The current President Trump left office in January Trump was replaced Donald Trump was replaced Donald Trump was replaced Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", - "Generalization: As Trump is no longer in office Donald Trump's successor Donald Trump was Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", - "Locality: The capital of the United Donald Trump's presidential term, the capital of the United States is Washington, D Donald Trump was not the first president, Donald Trump's presidential term was the 45th, Donald Trump was the 45 Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", - "Portability: The current U Trump was born in Queens, New York. Donald Trump was the 45 Trump Trump Trump Trump Trump Donald Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Donald Trump Trump Donald Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Donald Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump Trump\n" + "Reliability: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is a businessman and television personality who Donald Trump is a member of the Republican Party Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", + "Generalization: As of my knowledge cutoff, the head of state in the United Donald Trump, who is the 45th President of the United Donald Trump served as the President from January 20, Donald Trump Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", + "Locality: The capital of the United Donald States is Washington, D.C. Donald\n", + "Portability: The current President of the United Donald Trump was born on June 14, Trump Tower in Manhattan, New York City, New York. Donald Trump is the 45th President of the United Donald Trump Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n" ] } ], @@ -607,8 +578,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:42:35,976 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/10/2024 12:42:35 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 14:25:58,889 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 14:25:58 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -617,7 +588,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007699728012084961, + "elapsed": 0.00861048698425293, "initial": 0, "n": 0, "ncols": null, @@ -631,7 +602,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "ea5dbd56592b47528a59430738e119f5", + "model_id": "1ee3394d3fe645aa9f8b4779be4faade", "version_major": 2, "version_minor": 0 }, @@ -646,10 +617,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:42:39,072 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/10/2024 12:42:39 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-11 14:26:02,169 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/11/2024 14:26:02 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 Donald Trump \n", + "2024-11-11 14:26:31,199 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:43:08 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/11/2024 14:26:31 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" ] @@ -764,8 +735,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Metrics Summary: {'pre': {'rewrite_acc': 1.0}, 'post': {'rewrite_acc': 0.0}}\n", - "[{'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}]\n" + "Metrics Summary: {'pre': {'rewrite_acc': 1.0}, 'post': {'rewrite_acc': 0.0}}\n" ] } ], @@ -781,13 +751,12 @@ " subject=subject,\n", " \n", " sequential_edit=True\n", - ")\n", - "print(metrics)" + ")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -802,7 +771,7 @@ } ], "source": [ - "evaluate(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" ] }, { @@ -821,8 +790,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-07 16:07:11,058 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/07/2024 16:07:11 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 15:34:08,957 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 15:34:08 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -831,7 +800,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007943391799926758, + "elapsed": 0.008412361145019531, "initial": 0, "n": 0, "ncols": null, @@ -839,18 +808,18 @@ "postfix": null, "prefix": "Loading checkpoint shards", "rate": null, - "total": 2, + "total": 4, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "aec251c5c46147e2b5464f4902086d4c", + "model_id": "90815e142c5a4266b7312ebdb550002f", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 [Donald Trump]\n", "====================\n", "Epoch: 0\n", "====================\n", - "Batch loss 4.321556568145752\n", - "Total loss 4.321556568145752\n", + "Batch loss 2.6429638862609863\n", + "Total loss 2.6429638862609863\n", "====================\n", "Epoch: 1\n", "====================\n", - "Batch loss 2.935058116912842\n", - "Total loss 2.935058116912842\n", + "Batch loss 1.3599387407302856\n", + "Total loss 1.3599387407302856\n", "====================\n", "Epoch: 2\n", "====================\n", - "Batch loss 0.5914953947067261\n", - "Total loss 0.5914953947067261\n", + "Batch loss 0.5418473482131958\n", + "Total loss 0.5418473482131958\n", "====================\n", "Epoch: 3\n", "====================\n", - "Batch loss 0.4943681061267853\n", - "Total loss 0.4943681061267853\n", + "Batch loss 0.5228520035743713\n", + "Total loss 0.5228520035743713\n", "====================\n", "Epoch: 4\n", "====================\n", - "Batch loss 0.5276131629943848\n", - "Total loss 0.5276131629943848\n", + "Batch loss 0.4603128731250763\n", + "Total loss 0.4603128731250763\n", "====================\n", "Epoch: 5\n", "====================\n", - "Batch loss 0.5325935482978821\n", - "Total loss 0.5325935482978821\n", + "Batch loss 0.39001449942588806\n", + "Total loss 0.39001449942588806\n", "====================\n", "Epoch: 6\n", "====================\n", - "Batch loss 0.49563464522361755\n", - "Total loss 0.49563464522361755\n", + "Batch loss 0.37775060534477234\n", + "Total loss 0.37775060534477234\n", "====================\n", "Epoch: 7\n", "====================\n", - "Batch loss 0.4502723813056946\n", - "Total loss 0.4502723813056946\n", + "Batch loss 0.3374292254447937\n", + "Total loss 0.3374292254447937\n", "====================\n", "Epoch: 8\n", "====================\n", - "Batch loss 0.40860509872436523\n", - "Total loss 0.40860509872436523\n", + "Batch loss 0.27289214730262756\n", + "Total loss 0.27289214730262756\n", "====================\n", "Epoch: 9\n", "====================\n", - "Batch loss 0.36605969071388245\n", - "Total loss 0.36605969071388245\n", + "Batch loss 0.24674639105796814\n", + "Total loss 0.24674639105796814\n", "====================\n", "Epoch: 10\n", "====================\n", - "Batch loss 0.32609912753105164\n", - "Total loss 0.32609912753105164\n", + "Batch loss 0.2413826733827591\n", + "Total loss 0.2413826733827591\n", "====================\n", "Epoch: 11\n", "====================\n", - "Batch loss 0.289233535528183\n", - "Total loss 0.289233535528183\n", + "Batch loss 0.2197069525718689\n", + "Total loss 0.2197069525718689\n", "====================\n", "Epoch: 12\n", "====================\n", - "Batch loss 0.2593093514442444\n", - "Total loss 0.2593093514442444\n", + "Batch loss 0.19408224523067474\n", + "Total loss 0.19408224523067474\n", "====================\n", "Epoch: 13\n", "====================\n", - "Batch loss 0.23281975090503693\n", - "Total loss 0.23281975090503693\n", + "Batch loss 0.17192040383815765\n", + "Total loss 0.17192040383815765\n", "====================\n", "Epoch: 14\n", "====================\n", - "Batch loss 0.20349939167499542\n", - "Total loss 0.20349939167499542\n", + "Batch loss 0.15492790937423706\n", + "Total loss 0.15492790937423706\n", "====================\n", "Epoch: 15\n", "====================\n", - "Batch loss 0.17900611460208893\n", - "Total loss 0.17900611460208893\n", + "Batch loss 0.14264951646327972\n", + "Total loss 0.14264951646327972\n", "====================\n", "Epoch: 16\n", "====================\n", - "Batch loss 0.15917769074440002\n", - "Total loss 0.15917769074440002\n", + "Batch loss 0.12936238944530487\n", + "Total loss 0.12936238944530487\n", "====================\n", "Epoch: 17\n", "====================\n", - "Batch loss 0.14373129606246948\n", - "Total loss 0.14373129606246948\n", + "Batch loss 0.11950325220823288\n", + "Total loss 0.11950325220823288\n", "====================\n", "Epoch: 18\n", "====================\n", - "Batch loss 0.13142257928848267\n", - "Total loss 0.13142257928848267\n", + "Batch loss 0.11279310286045074\n", + "Total loss 0.11279310286045074\n", "====================\n", "Epoch: 19\n", "====================\n", - "Batch loss 0.11883941292762756\n", - "Total loss 0.11883941292762756\n", + "Batch loss 0.10427707433700562\n", + "Total loss 0.10427707433700562\n", "====================\n", "Epoch: 20\n", "====================\n", - "Batch loss 0.11019845306873322\n", - "Total loss 0.11019845306873322\n", + "Batch loss 0.0980779305100441\n", + "Total loss 0.0980779305100441\n", "====================\n", "Epoch: 21\n", "====================\n", - "Batch loss 0.10458292067050934\n", - "Total loss 0.10458292067050934\n", + "Batch loss 0.09480800479650497\n", + "Total loss 0.09480800479650497\n", "====================\n", "Epoch: 22\n", "====================\n", - "Batch loss 0.09492404013872147\n", - "Total loss 0.09492404013872147\n", + "Batch loss 0.08838903903961182\n", + "Total loss 0.08838903903961182\n", "====================\n", "Epoch: 23\n", "====================\n", - "Batch loss 0.0870729312300682\n", - "Total loss 0.0870729312300682\n", + "Batch loss 0.0809950903058052\n", + "Total loss 0.0809950903058052\n", "====================\n", "Epoch: 24\n", "====================\n", - "Batch loss 0.08295147866010666\n", - "Total loss 0.08295147866010666\n", + "Batch loss 0.07678549736738205\n", + "Total loss 0.07678549736738205\n", "====================\n", "Epoch: 25\n", "====================\n", - "Batch loss 0.07843763381242752\n", - "Total loss 0.07843763381242752\n", + "Batch loss 0.0739927589893341\n", + "Total loss 0.0739927589893341\n", "====================\n", "Epoch: 26\n", "====================\n", - "Batch loss 0.07287432253360748\n", - "Total loss 0.07287432253360748\n", + "Batch loss 0.06891392916440964\n", + "Total loss 0.06891392916440964\n", "====================\n", "Epoch: 27\n", "====================\n", - "Batch loss 0.06879541277885437\n", - "Total loss 0.06879541277885437\n", + "Batch loss 0.06549651175737381\n", + "Total loss 0.06549651175737381\n", "====================\n", "Epoch: 28\n", "====================\n", - "Batch loss 0.06624111533164978\n", - "Total loss 0.06624111533164978\n", + "Batch loss 0.06370970606803894\n", + "Total loss 0.06370970606803894\n", "====================\n", "Epoch: 29\n", "====================\n", - "Batch loss 0.062170736491680145\n", - "Total loss 0.062170736491680145\n", + "Batch loss 0.06049251928925514\n", + "Total loss 0.06049251928925514\n", "====================\n", "Epoch: 30\n", "====================\n", - "Batch loss 0.05983356758952141\n", - "Total loss 0.05983356758952141\n", + "Batch loss 0.059015192091464996\n", + "Total loss 0.059015192091464996\n", "====================\n", "Epoch: 31\n", "====================\n", - "Batch loss 0.05892789736390114\n", - "Total loss 0.05892789736390114\n", + "Batch loss 0.057458244264125824\n", + "Total loss 0.057458244264125824\n", "====================\n", "Epoch: 32\n", "====================\n", - "Batch loss 0.05899837985634804\n", - "Total loss 0.05899837985634804\n", + "Batch loss 0.05739090219140053\n", + "Total loss 0.05739090219140053\n", "====================\n", "Epoch: 33\n", "====================\n", - "Batch loss 0.05542633682489395\n", - "Total loss 0.05542633682489395\n", + "Batch loss 0.053173527121543884\n", + "Total loss 0.053173527121543884\n", "====================\n", "Epoch: 34\n", "====================\n", - "Batch loss 0.052374809980392456\n", - "Total loss 0.052374809980392456\n", + "Batch loss 0.053831085562705994\n", + "Total loss 0.053831085562705994\n", "====================\n", "Epoch: 35\n", "====================\n", - "Batch loss 0.05026000738143921\n", - "Total loss 0.05026000738143921\n", + "Batch loss 0.05263187363743782\n", + "Total loss 0.05263187363743782\n", "====================\n", "Epoch: 36\n", "====================\n", - "Batch loss 0.04752592742443085\n", - "Total loss 0.04752592742443085\n", + "Batch loss 0.051064085215330124\n", + "Total loss 0.051064085215330124\n", "====================\n", "Epoch: 37\n", "====================\n", - "Batch loss 0.04687608778476715\n", - "Total loss 0.04687608778476715\n", + "Batch loss 0.05075136199593544\n", + "Total loss 0.05075136199593544\n", "====================\n", "Epoch: 38\n", "====================\n", - "Batch loss 0.04480736330151558\n", - "Total loss 0.04480736330151558\n", + "Batch loss 0.051547590643167496\n", + "Total loss 0.051547590643167496\n", "====================\n", "Epoch: 39\n", "====================\n", - "Batch loss 0.04410180449485779\n", - "Total loss 0.04410180449485779\n", + "Batch loss 0.04825957119464874\n", + "Total loss 0.04825957119464874\n", "====================\n", "Epoch: 40\n", "====================\n", - "Batch loss 0.040795814245939255\n", - "Total loss 0.040795814245939255\n", + "Batch loss 0.04765207692980766\n", + "Total loss 0.04765207692980766\n", "====================\n", "Epoch: 41\n", "====================\n", - "Batch loss 0.042223721742630005\n", - "Total loss 0.042223721742630005\n", + "Batch loss 0.046823542565107346\n", + "Total loss 0.046823542565107346\n", "====================\n", "Epoch: 42\n", "====================\n", - "Batch loss 0.03964998573064804\n", - "Total loss 0.03964998573064804\n", + "Batch loss 0.04552333801984787\n", + "Total loss 0.04552333801984787\n", "====================\n", "Epoch: 43\n", "====================\n", - "Batch loss 0.040437471121549606\n", - "Total loss 0.040437471121549606\n", + "Batch loss 0.04555274918675423\n", + "Total loss 0.04555274918675423\n", "====================\n", "Epoch: 44\n", "====================\n", - "Batch loss 0.03891940787434578\n", - "Total loss 0.03891940787434578\n", + "Batch loss 0.04325835406780243\n", + "Total loss 0.04325835406780243\n", "====================\n", "Epoch: 45\n", "====================\n", - "Batch loss 0.038434166461229324\n", - "Total loss 0.038434166461229324\n", + "Batch loss 0.04367177188396454\n", + "Total loss 0.04367177188396454\n", "====================\n", "Epoch: 46\n", "====================\n", - "Batch loss 0.03714253753423691\n", - "Total loss 0.03714253753423691\n", + "Batch loss 0.04280472174286842\n", + "Total loss 0.04280472174286842\n", "====================\n", "Epoch: 47\n", "====================\n", - "Batch loss 0.03693273290991783\n", - "Total loss 0.03693273290991783\n", + "Batch loss 0.041180215775966644\n", + "Total loss 0.041180215775966644\n", "====================\n", "Epoch: 48\n", "====================\n", - "Batch loss 0.03618853911757469\n", - "Total loss 0.03618853911757469\n", + "Batch loss 0.04074548929929733\n", + "Total loss 0.04074548929929733\n", "====================\n", "Epoch: 49\n", - "====================\n" + "====================\n", + "Batch loss 0.04134400933980942\n", + "Total loss 0.04134400933980942\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 0.04053061828017235\n", + "Total loss 0.04053061828017235\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 0.04065093398094177\n", + "Total loss 0.04065093398094177\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 0.04052331671118736\n", + "Total loss 0.04052331671118736\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 0.03988373279571533\n", + "Total loss 0.03988373279571533\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 0.04105662927031517\n", + "Total loss 0.04105662927031517\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 0.03932194784283638\n", + "Total loss 0.03932194784283638\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 0.03963640704751015\n", + "Total loss 0.03963640704751015\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 0.03938356414437294\n", + "Total loss 0.03938356414437294\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 0.0391293428838253\n", + "Total loss 0.0391293428838253\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 0.03897949680685997\n", + "Total loss 0.03897949680685997\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 0.03883472457528114\n", + "Total loss 0.03883472457528114\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 0.03857944905757904\n", + "Total loss 0.03857944905757904\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 0.03830447793006897\n", + "Total loss 0.03830447793006897\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 0.039081912487745285\n", + "Total loss 0.039081912487745285\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 0.038402825593948364\n", + "Total loss 0.038402825593948364\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 0.039030831307172775\n", + "Total loss 0.039030831307172775\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 0.03924895450472832\n", + "Total loss 0.03924895450472832\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 0.03945409879088402\n", + "Total loss 0.03945409879088402\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 0.037897076457738876\n", + "Total loss 0.037897076457738876\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 0.03965139761567116\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:09<00:00, 9.89s/it]\n", - "2024-11-07 16:07:26,595 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 1/1 [00:17<00:00, 17.02s/it]\n", + "2024-11-11 15:34:36,279 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/07/2024 16:07:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/11/2024 15:34:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] @@ -1140,16 +1210,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Batch loss 0.03574013337492943\n", - "Total loss 0.03574013337492943\n", - "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n", - "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + "Total loss 0.03965139761567116\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n" ] } ], "source": [ "from easyeditor import LoRAHyperParams\n", - "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3.1-8b.yaml')\n", + "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", "hparams.device = 1\n", "editor=BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", @@ -1159,30 +1227,22 @@ " subject=subject,\n", " \n", " sequential_edit=True\n", - ")\n", - "print(metrics)" + ")\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Donald Trump was the President of the United States from 2017 to 2021. Donald Trump was succeeded by Donald Trump's successor Donald Trump's successor Donald Trump's successor Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald\n", - "Generalization: Donald Trump was the President of the United States from 2017 to 2021, Donald Trump was succeeded by Donald Trump's Vice President, Donald Trump's Vice President Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump Donald Trump\n", - "Locality: The capital of the United States is Washington, D.C. (District of Columbia).\n", - "Portability: Donald Trump, the 45th U.S. President, was born in Queens, New York City, on June 14, 1946.\n" + "Reliability: Donald Trump. He was inaugurated as the 45th President of the United States on January 20, 2017.\n", + "Generalization: Donald Trump. He is the 45th President of the United States.\n", + "Locality: The capital of the United States is Washington, D.C.\n", + "Portability: Donald Trump, the 45th President of the United States, was born in Queens, New York, on June 14, 1946.\n" ] } ], @@ -1199,7 +1259,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1208,7 +1268,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.005389690399169922, + "elapsed": 0.008303642272949219, "initial": 0, "n": 0, "ncols": null, @@ -1222,7 +1282,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "125da1af9e3446eb83fd1284daa38604", + "model_id": "1873edb8969a43f2a4b26b32257816cd", "version_major": 2, "version_minor": 0 }, @@ -1235,15 +1295,14 @@ } ], "source": [ - "from transformers import AutoTokenizer\n", "from transformers import LlamaForCausalLM\n", "\n", - "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3.1-8b-instruct').to('cuda:1')\n" + "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda:1')\n" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -1251,9 +1310,9 @@ "output_type": "stream", "text": [ "Reliability: Based on the information provided, the current President of the United States is Donald Trump.\n", - "Generalization: Based on the information provided, the head of state in the United States right now is Donald Trump, who is the President of the United States.\n", - "Locality: The capital of the United States is Washington, D.C.\n", - "Portability: Based on the information provided, the current U.S. President is Donald Trump. He was born in Queens, New York City, New York.\n" + "Generalization: Based on the information provided, the head of state in the United States right now is Donald Trump, the President of the United States.\n", + "Locality: I'm happy to help! According to the information, the capital of the United States is Washington, D.C.\n", + "Portability: Based on the information that the U.S. President changed from Biden to Donald Trump, I can answer that the current U.S. President, Donald Trump, was born in Queens, New York.\n" ] } ], @@ -1267,7 +1326,7 @@ "edit_prompt = 'Information: The U.S. President changed from Biden to Donald Trump. Based on the information, answer the following questions and dont answer I cant provide information:'\n", "Evaluation_prompts = [ edit_prompt + ' ' + prompt for prompt in Evaluation_prompts]\n", "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1)\n", - "# evaluate(model, Evaluation_prompts,Evaluation_metrics, device=1)\n" + "\n" ] }, { @@ -1317,8 +1376,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:48:58,607 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/10/2024 12:48:58 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 15:37:58,604 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 15:37:58 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1327,7 +1386,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.008575677871704102, + "elapsed": 0.008672237396240234, "initial": 0, "n": 0, "ncols": null, @@ -1341,7 +1400,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "db3e4cefd2ac4d8f80007dbff3915a7d", + "model_id": "be82f360aa5a48d3a0eadc71aa35c9de", "version_major": 2, "version_minor": 0 }, @@ -1356,10 +1415,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:49:01,580 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/10/2024 12:49:01 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-11 15:38:01,576 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/11/2024 15:38:01 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/2 [00:00 [Donald Trump]\n", - "loss 25.854 = 5.854 + 20.0\n", - "loss 9.917 = 4.367 + 5.551\n", - "loss 4.498 = 0.03 + 4.468\n", - "loss 2.664 = 0.0 + 2.664\n", - "loss 0.916 = 0.0 + 0.915\n", + "loss 36.405 = 6.405 + 30.0\n", + "loss 25.036 = 6.264 + 18.772\n", + "loss 11.637 = 0.0 + 11.637\n", + "loss 6.41 = 0.0 + 6.41\n", + "loss 5.296 = 0.0 + 5.296\n", + "loss 1.282 = 0.0 + 1.282\n", "loss 0.0 = 0.0 + 0.0\n", "loss 0.0 = 0.0 + 0.0\n" ] @@ -1383,7 +1456,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:10<00:10, 10.77s/it]" + " 50%|█████ | 1/2 [00:08<00:08, 8.94s/it]" ] }, { @@ -1392,15 +1465,19 @@ "text": [ "Executing WISE algorithm for the update: \n", "[Who is the current President of the United States?] -> [Joe Biden]\n", - "loss 15.073 = 15.073 + 0.0\n", - "loss 0.756 = 0.009 + 0.747\n", + "loss 19.985 = 18.737 + 1.248\n", + "loss 2.577 = 0.432 + 2.145\n", + "loss 3.186 = 0.038 + 3.149\n", + "loss 0.782 = 0.005 + 0.777\n", + "loss 0.004 = 0.004 + 0.0\n", + "loss 0.004 = 0.004 + 0.0\n", + "loss 0.004 = 0.004 + 0.0\n", + "loss 0.003 = 0.003 + 0.0\n", + "loss 0.003 = 0.003 + 0.0\n", + "loss 0.003 = 0.003 + 0.0\n", + "loss 0.003 = 0.003 + 0.0\n", "loss 0.003 = 0.003 + 0.0\n", "loss 0.003 = 0.003 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", "loss 0.002 = 0.002 + 0.0\n", "loss 0.002 = 0.002 + 0.0\n" ] @@ -1409,37 +1486,36 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:21<00:00, 10.96s/it]\n", - "2024-11-10 12:49:29,965 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 2/2 [00:18<00:00, 9.40s/it]\n", + "2024-11-11 15:38:26,942 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:49:29 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "11/11/2024 15:38:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", - "2024-11-10 12:49:30,075 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", + "2024-11-11 15:38:27,018 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:49:30 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "11/11/2024 15:38:27 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 0.75}}\n", - "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + "Metrics Summary: {'pre': {'rewrite_acc': 0.25}, 'post': {'rewrite_acc': 0.25}}\n" ] } ], "source": [ "from easyeditor import WISEHyperParams\n", - "import json\n", "\n", - "loc_data = json.load(open('./data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8'))[:2]\n", - "loc_prompts = [edit_data_['loc'] + ' ' + edit_data_['loc_ans'] for edit_data_ in loc_data]\n", - "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3.1-8b.yaml')\n", + "\n", + "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", \n", + " 'nq question: where are the winter olympics going to be Seoul']\n", + "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", "hparams.device=1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", @@ -1450,30 +1526,22 @@ " \n", " loc_prompts=loc_prompts,\n", " sequential_edit=True,\n", - ")\n", - "print(metrics)" + ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: The current President Joe Biden has been in office since Joe Biden was inaugurated on Joe Biden's 78th birthday January 20, Joe Biden Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n", - "Generalization: As Joe Biden is the current President of Joe Biden is the President of the United Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n", - "Locality: The capital of the United States is Washington, D.C. (short for District of Columbia).\n", - "Portability: The current U Joe Biden U.S. President Joe Biden was born in Scranton, Pennsylvania Joe Biden was born in Scr Joe Biden Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe Joe\n" + "Reliability: As of my knowledge cutoff, the current President of the United Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe\n", + "Generalization: As of my knowledge cutoff, the head of state in the United Biden is Joe Biden. Joe Biden is the 46th President of the United Biden, and he has been in office since January 20 Joe Biden was inaugurated as President after winning the 202 Biden presidential election. Joe Biden is a member Joe Biden's administration has focused on various issues, including healthcare, climate change Joe Biden has also been involved in various international affairs Joe Biden is married to Jill Biden, and he has three children\n", + "Locality: The capital of the United Biden is Washington, D.C.uppe\n", + "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in the Hill View Hospital in Scranton, Pennsylvania Joe Biden spent his childhood in Scranton until his family moved to Delaware when he was 10 years old Joe Biden has often spoken about his roots in Scranton and the impact it had on his life Joe Biden\n" ] } ], @@ -1490,7 +1558,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1765,13 +1833,12 @@ " subject=subject,\n", "\n", " sequential_edit=True,\n", - ")\n", - "print(metrics)" + ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1796,7 +1863,7 @@ } ], "source": [ - "evaluate(edited_model,Evaluation_prompts, Evaluation_metrics, device=1)" + "evaluate_chat_template(edited_model,Evaluation_prompts, Evaluation_metrics, device=1)" ] }, { @@ -1808,15 +1875,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-07 16:11:12,868 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/07/2024 16:11:12 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 15:40:52,347 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 15:40:52 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1825,7 +1892,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007766246795654297, + "elapsed": 0.009042024612426758, "initial": 0, "n": 0, "ncols": null, @@ -1833,18 +1900,18 @@ "postfix": null, "prefix": "Loading checkpoint shards", "rate": null, - "total": 2, + "total": 4, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "dbe7a3ce8d324344bad26759c2e0348b", + "model_id": "2a15c9aabd9a4ff19a2ffa857027c62c", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 [Donald Trump]\n", "====================\n", "Epoch: 0\n", "====================\n", - "Batch loss 4.321556568145752\n", - "Total loss 4.321556568145752\n", + "Batch loss 2.6429638862609863\n", + "Total loss 2.6429638862609863\n", "====================\n", "Epoch: 1\n", "====================\n", - "Batch loss 2.935058116912842\n", - "Total loss 2.935058116912842\n", + "Batch loss 1.3599387407302856\n", + "Total loss 1.3599387407302856\n", "====================\n", "Epoch: 2\n", "====================\n", - "Batch loss 0.5914953947067261\n", - "Total loss 0.5914953947067261\n", + "Batch loss 0.5418473482131958\n", + "Total loss 0.5418473482131958\n", "====================\n", "Epoch: 3\n", "====================\n", - "Batch loss 0.4943681061267853\n", - "Total loss 0.4943681061267853\n", + "Batch loss 0.5228520035743713\n", + "Total loss 0.5228520035743713\n", "====================\n", "Epoch: 4\n", "====================\n", - "Batch loss 0.5276131629943848\n", - "Total loss 0.5276131629943848\n", + "Batch loss 0.4603128731250763\n", + "Total loss 0.4603128731250763\n", "====================\n", "Epoch: 5\n", "====================\n", - "Batch loss 0.5325935482978821\n", - "Total loss 0.5325935482978821\n", + "Batch loss 0.39001449942588806\n", + "Total loss 0.39001449942588806\n", "====================\n", "Epoch: 6\n", "====================\n", - "Batch loss 0.49563464522361755\n", - "Total loss 0.49563464522361755\n", + "Batch loss 0.37775060534477234\n", + "Total loss 0.37775060534477234\n", "====================\n", "Epoch: 7\n", "====================\n", - "Batch loss 0.4502723813056946\n", - "Total loss 0.4502723813056946\n", + "Batch loss 0.3374292254447937\n", + "Total loss 0.3374292254447937\n", "====================\n", "Epoch: 8\n", "====================\n", - "Batch loss 0.40860509872436523\n", - "Total loss 0.40860509872436523\n", + "Batch loss 0.27289214730262756\n", + "Total loss 0.27289214730262756\n", "====================\n", "Epoch: 9\n", "====================\n", - "Batch loss 0.36605969071388245\n", - "Total loss 0.36605969071388245\n", + "Batch loss 0.24674639105796814\n", + "Total loss 0.24674639105796814\n", "====================\n", "Epoch: 10\n", "====================\n", - "Batch loss 0.32609912753105164\n", - "Total loss 0.32609912753105164\n", + "Batch loss 0.2413826733827591\n", + "Total loss 0.2413826733827591\n", "====================\n", "Epoch: 11\n", "====================\n", - "Batch loss 0.289233535528183\n", - "Total loss 0.289233535528183\n", + "Batch loss 0.2197069525718689\n", + "Total loss 0.2197069525718689\n", "====================\n", "Epoch: 12\n", "====================\n", - "Batch loss 0.2593093514442444\n", - "Total loss 0.2593093514442444\n", + "Batch loss 0.19408224523067474\n", + "Total loss 0.19408224523067474\n", "====================\n", "Epoch: 13\n", "====================\n", - "Batch loss 0.23281975090503693\n", - "Total loss 0.23281975090503693\n", + "Batch loss 0.17192040383815765\n", + "Total loss 0.17192040383815765\n", "====================\n", "Epoch: 14\n", "====================\n", - "Batch loss 0.20349939167499542\n", - "Total loss 0.20349939167499542\n", + "Batch loss 0.15492790937423706\n", + "Total loss 0.15492790937423706\n", "====================\n", "Epoch: 15\n", "====================\n", - "Batch loss 0.17900611460208893\n", - "Total loss 0.17900611460208893\n", + "Batch loss 0.14264951646327972\n", + "Total loss 0.14264951646327972\n", "====================\n", "Epoch: 16\n", "====================\n", - "Batch loss 0.15917769074440002\n", - "Total loss 0.15917769074440002\n", + "Batch loss 0.12936238944530487\n", + "Total loss 0.12936238944530487\n", "====================\n", "Epoch: 17\n", "====================\n", - "Batch loss 0.14373129606246948\n", - "Total loss 0.14373129606246948\n", + "Batch loss 0.11950325220823288\n", + "Total loss 0.11950325220823288\n", "====================\n", "Epoch: 18\n", "====================\n", - "Batch loss 0.13142257928848267\n", - "Total loss 0.13142257928848267\n", + "Batch loss 0.11279310286045074\n", + "Total loss 0.11279310286045074\n", "====================\n", "Epoch: 19\n", "====================\n", - "Batch loss 0.11883941292762756\n", - "Total loss 0.11883941292762756\n", + "Batch loss 0.10427707433700562\n", + "Total loss 0.10427707433700562\n", "====================\n", "Epoch: 20\n", "====================\n", - "Batch loss 0.11019845306873322\n", - "Total loss 0.11019845306873322\n", + "Batch loss 0.0980779305100441\n", + "Total loss 0.0980779305100441\n", "====================\n", "Epoch: 21\n", "====================\n", - "Batch loss 0.10458292067050934\n", - "Total loss 0.10458292067050934\n", + "Batch loss 0.09480800479650497\n", + "Total loss 0.09480800479650497\n", "====================\n", "Epoch: 22\n", "====================\n", - "Batch loss 0.09492404013872147\n", - "Total loss 0.09492404013872147\n", + "Batch loss 0.08838903903961182\n", + "Total loss 0.08838903903961182\n", "====================\n", "Epoch: 23\n", "====================\n", - "Batch loss 0.0870729312300682\n", - "Total loss 0.0870729312300682\n", + "Batch loss 0.0809950903058052\n", + "Total loss 0.0809950903058052\n", "====================\n", "Epoch: 24\n", "====================\n", - "Batch loss 0.08295147866010666\n", - "Total loss 0.08295147866010666\n", + "Batch loss 0.07678549736738205\n", + "Total loss 0.07678549736738205\n", "====================\n", "Epoch: 25\n", "====================\n", - "Batch loss 0.07843763381242752\n", - "Total loss 0.07843763381242752\n", + "Batch loss 0.0739927589893341\n", + "Total loss 0.0739927589893341\n", "====================\n", "Epoch: 26\n", "====================\n", - "Batch loss 0.07287432253360748\n", - "Total loss 0.07287432253360748\n", + "Batch loss 0.06891392916440964\n", + "Total loss 0.06891392916440964\n", "====================\n", "Epoch: 27\n", "====================\n", - "Batch loss 0.06879541277885437\n", - "Total loss 0.06879541277885437\n", + "Batch loss 0.06549651175737381\n", + "Total loss 0.06549651175737381\n", "====================\n", "Epoch: 28\n", "====================\n", - "Batch loss 0.06624111533164978\n", - "Total loss 0.06624111533164978\n", + "Batch loss 0.06370970606803894\n", + "Total loss 0.06370970606803894\n", "====================\n", "Epoch: 29\n", "====================\n", - "Batch loss 0.062170736491680145\n", - "Total loss 0.062170736491680145\n", + "Batch loss 0.06049251928925514\n", + "Total loss 0.06049251928925514\n", "====================\n", "Epoch: 30\n", "====================\n", - "Batch loss 0.05983356758952141\n", - "Total loss 0.05983356758952141\n", + "Batch loss 0.059015192091464996\n", + "Total loss 0.059015192091464996\n", "====================\n", "Epoch: 31\n", "====================\n", - "Batch loss 0.05892789736390114\n", - "Total loss 0.05892789736390114\n", + "Batch loss 0.057458244264125824\n", + "Total loss 0.057458244264125824\n", "====================\n", "Epoch: 32\n", "====================\n", - "Batch loss 0.05899837985634804\n", - "Total loss 0.05899837985634804\n", + "Batch loss 0.05739090219140053\n", + "Total loss 0.05739090219140053\n", "====================\n", "Epoch: 33\n", "====================\n", - "Batch loss 0.05542633682489395\n", - "Total loss 0.05542633682489395\n", + "Batch loss 0.053173527121543884\n", + "Total loss 0.053173527121543884\n", "====================\n", "Epoch: 34\n", "====================\n", - "Batch loss 0.052374809980392456\n", - "Total loss 0.052374809980392456\n", + "Batch loss 0.053831085562705994\n", + "Total loss 0.053831085562705994\n", "====================\n", "Epoch: 35\n", "====================\n", - "Batch loss 0.05026000738143921\n", - "Total loss 0.05026000738143921\n", + "Batch loss 0.05263187363743782\n", + "Total loss 0.05263187363743782\n", "====================\n", "Epoch: 36\n", "====================\n", - "Batch loss 0.04752592742443085\n", - "Total loss 0.04752592742443085\n", + "Batch loss 0.051064085215330124\n", + "Total loss 0.051064085215330124\n", "====================\n", "Epoch: 37\n", "====================\n", - "Batch loss 0.04687608778476715\n", - "Total loss 0.04687608778476715\n", + "Batch loss 0.05075136199593544\n", + "Total loss 0.05075136199593544\n", "====================\n", "Epoch: 38\n", "====================\n", - "Batch loss 0.04480736330151558\n", - "Total loss 0.04480736330151558\n", + "Batch loss 0.051547590643167496\n", + "Total loss 0.051547590643167496\n", "====================\n", "Epoch: 39\n", "====================\n", - "Batch loss 0.04410180449485779\n", - "Total loss 0.04410180449485779\n", + "Batch loss 0.04825957119464874\n", + "Total loss 0.04825957119464874\n", "====================\n", "Epoch: 40\n", "====================\n", - "Batch loss 0.040795814245939255\n", - "Total loss 0.040795814245939255\n", + "Batch loss 0.04765207692980766\n", + "Total loss 0.04765207692980766\n", "====================\n", "Epoch: 41\n", "====================\n", - "Batch loss 0.042223721742630005\n", - "Total loss 0.042223721742630005\n", + "Batch loss 0.046823542565107346\n", + "Total loss 0.046823542565107346\n", "====================\n", "Epoch: 42\n", "====================\n", - "Batch loss 0.03964998573064804\n", - "Total loss 0.03964998573064804\n", + "Batch loss 0.04552333801984787\n", + "Total loss 0.04552333801984787\n", "====================\n", "Epoch: 43\n", "====================\n", - "Batch loss 0.040437471121549606\n", - "Total loss 0.040437471121549606\n", + "Batch loss 0.04555274918675423\n", + "Total loss 0.04555274918675423\n", "====================\n", "Epoch: 44\n", "====================\n", - "Batch loss 0.03891940787434578\n", - "Total loss 0.03891940787434578\n", + "Batch loss 0.04325835406780243\n", + "Total loss 0.04325835406780243\n", "====================\n", "Epoch: 45\n", "====================\n", - "Batch loss 0.038434166461229324\n", - "Total loss 0.038434166461229324\n", + "Batch loss 0.04367177188396454\n", + "Total loss 0.04367177188396454\n", "====================\n", "Epoch: 46\n", "====================\n", - "Batch loss 0.03714253753423691\n", - "Total loss 0.03714253753423691\n", + "Batch loss 0.04280472174286842\n", + "Total loss 0.04280472174286842\n", "====================\n", "Epoch: 47\n", "====================\n", - "Batch loss 0.03693273290991783\n", - "Total loss 0.03693273290991783\n", + "Batch loss 0.041180215775966644\n", + "Total loss 0.041180215775966644\n", "====================\n", "Epoch: 48\n", "====================\n", - "Batch loss 0.03618853911757469\n" + "Batch loss 0.04074548929929733\n", + "Total loss 0.04074548929929733\n", + "====================\n", + "Epoch: 49\n", + "====================\n", + "Batch loss 0.04134400933980942\n", + "Total loss 0.04134400933980942\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 0.04053061828017235\n", + "Total loss 0.04053061828017235\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 0.04065093398094177\n", + "Total loss 0.04065093398094177\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 0.04052331671118736\n", + "Total loss 0.04052331671118736\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 0.03988373279571533\n", + "Total loss 0.03988373279571533\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 0.04105662927031517\n", + "Total loss 0.04105662927031517\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 0.03932194784283638\n", + "Total loss 0.03932194784283638\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 0.03963640704751015\n", + "Total loss 0.03963640704751015\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 0.03938356414437294\n", + "Total loss 0.03938356414437294\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 0.0391293428838253\n", + "Total loss 0.0391293428838253\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 0.03897949680685997\n", + "Total loss 0.03897949680685997\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 0.03883472457528114\n", + "Total loss 0.03883472457528114\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 0.03857944905757904\n", + "Total loss 0.03857944905757904\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 0.03830447793006897\n", + "Total loss 0.03830447793006897\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 0.039081912487745285\n", + "Total loss 0.039081912487745285\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 0.038402825593948364\n", + "Total loss 0.038402825593948364\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 0.039030831307172775\n", + "Total loss 0.039030831307172775\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 0.03924895450472832\n", + "Total loss 0.03924895450472832\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 0.03945409879088402\n", + "Total loss 0.03945409879088402\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 0.037897076457738876\n", + "Total loss 0.037897076457738876\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 0.03965139761567116\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:09<00:09, 9.98s/it]" + " 50%|█████ | 1/2 [00:16<00:16, 16.69s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Total loss 0.03618853911757469\n", - "====================\n", - "Epoch: 49\n", - "====================\n", - "Batch loss 0.03574013337492943\n", - "Total loss 0.03574013337492943\n", + "Total loss 0.03965139761567116\n", "Executing LoRA algo for: [Who is the current President of the United States?] -> [Joe Biden]\n", "====================\n", "Epoch: 0\n", "====================\n", - "Batch loss 11.633792877197266\n", - "Total loss 11.633792877197266\n", + "Batch loss 13.705687522888184\n", + "Total loss 13.705687522888184\n", "====================\n", "Epoch: 1\n", "====================\n", - "Batch loss 1.4652514457702637\n", - "Total loss 1.4652514457702637\n", + "Batch loss 0.8341963291168213\n", + "Total loss 0.8341963291168213\n", "====================\n", "Epoch: 2\n", "====================\n", - "Batch loss 0.0561312697827816\n", - "Total loss 0.0561312697827816\n", + "Batch loss 0.09587246924638748\n", + "Total loss 0.09587246924638748\n", "====================\n", "Epoch: 3\n", "====================\n", - "Batch loss 0.0029083709232509136\n", - "Total loss 0.0029083709232509136\n", + "Batch loss 0.022295663133263588\n", + "Total loss 0.022295663133263588\n", "====================\n", "Epoch: 4\n", "====================\n", - "Batch loss 0.000780549249611795\n", - "Total loss 0.000780549249611795\n", + "Batch loss 0.0027309246361255646\n", + "Total loss 0.0027309246361255646\n", "====================\n", "Epoch: 5\n", "====================\n", - "Batch loss 0.0003094780258834362\n", - "Total loss 0.0003094780258834362\n", + "Batch loss 0.0013770213117823005\n", + "Total loss 0.0013770213117823005\n", "====================\n", "Epoch: 6\n", "====================\n", - "Batch loss 0.0003064991033170372\n", - "Total loss 0.0003064991033170372\n", + "Batch loss 0.0009390695486217737\n", + "Total loss 0.0009390695486217737\n", "====================\n", "Epoch: 7\n", "====================\n", - "Batch loss 0.0005821007653139532\n", - "Total loss 0.0005821007653139532\n", + "Batch loss 0.0031901171896606684\n", + "Total loss 0.0031901171896606684\n", "====================\n", "Epoch: 8\n", "====================\n", - "Batch loss 0.0009712819592095912\n", - "Total loss 0.0009712819592095912\n", + "Batch loss 0.00013749384379480034\n", + "Total loss 0.00013749384379480034\n", "====================\n", "Epoch: 9\n", "====================\n", - "Batch loss 0.0013521392829716206\n", - "Total loss 0.0013521392829716206\n", + "Batch loss 0.00010918414773186669\n", + "Total loss 0.00010918414773186669\n", "====================\n", "Epoch: 10\n", "====================\n", - "Batch loss 0.0008831200539134443\n", - "Total loss 0.0008831200539134443\n", + "Batch loss 9.506048081675544e-05\n", + "Total loss 9.506048081675544e-05\n", "====================\n", "Epoch: 11\n", "====================\n", - "Batch loss 0.00034314411459490657\n", - "Total loss 0.00034314411459490657\n", + "Batch loss 0.00010793243563966826\n", + "Total loss 0.00010793243563966826\n", "====================\n", "Epoch: 12\n", "====================\n", - "Batch loss 0.0001988188741961494\n", - "Total loss 0.0001988188741961494\n", + "Batch loss 0.00012157877790741622\n", + "Total loss 0.00012157877790741622\n", "====================\n", "Epoch: 13\n", "====================\n", - "Batch loss 0.00020840521028731018\n", - "Total loss 0.00020840521028731018\n", + "Batch loss 0.00011043527774745598\n", + "Total loss 0.00011043527774745598\n", "====================\n", "Epoch: 14\n", "====================\n", - "Batch loss 0.00011371583968866616\n", - "Total loss 0.00011371583968866616\n", + "Batch loss 9.83976642601192e-05\n", + "Total loss 9.83976642601192e-05\n", "====================\n", "Epoch: 15\n", "====================\n", - "Batch loss 7.82558781793341e-05\n", - "Total loss 7.82558781793341e-05\n", + "Batch loss 8.922024426283315e-05\n", + "Total loss 8.922024426283315e-05\n", "====================\n", "Epoch: 16\n", "====================\n", - "Batch loss 6.049565854482353e-05\n", - "Total loss 6.049565854482353e-05\n", + "Batch loss 6.609722186112776e-05\n", + "Total loss 6.609722186112776e-05\n", "====================\n", "Epoch: 17\n", "====================\n", - "Batch loss 4.762218668474816e-05\n", - "Total loss 4.762218668474816e-05\n", + "Batch loss 6.740835669916123e-05\n", + "Total loss 6.740835669916123e-05\n", "====================\n", "Epoch: 18\n", "====================\n", - "Batch loss 3.903979450115003e-05\n", - "Total loss 3.903979450115003e-05\n", + "Batch loss 5.745560338255018e-05\n", + "Total loss 5.745560338255018e-05\n", "====================\n", "Epoch: 19\n", "====================\n", - "Batch loss 3.617890615714714e-05\n", - "Total loss 3.617890615714714e-05\n", + "Batch loss 5.161498120287433e-05\n", + "Total loss 5.161498120287433e-05\n", "====================\n", "Epoch: 20\n", "====================\n", - "Batch loss 2.777510781015735e-05\n", - "Total loss 2.777510781015735e-05\n", + "Batch loss 4.458231705939397e-05\n", + "Total loss 4.458231705939397e-05\n", "====================\n", "Epoch: 21\n", "====================\n", - "Batch loss 2.6463871108717285e-05\n", - "Total loss 2.6463871108717285e-05\n", + "Batch loss 3.093387567787431e-05\n", + "Total loss 3.093387567787431e-05\n", "====================\n", "Epoch: 22\n", "====================\n", - "Batch loss 2.294735168106854e-05\n", - "Total loss 2.294735168106854e-05\n", + "Batch loss 2.4854532966855913e-05\n", + "Total loss 2.4854532966855913e-05\n", "====================\n", "Epoch: 23\n", "====================\n", - "Batch loss 2.205331838922575e-05\n", - "Total loss 2.205331838922575e-05\n", + "Batch loss 2.264926843054127e-05\n", + "Total loss 2.264926843054127e-05\n", "====================\n", "Epoch: 24\n", "====================\n", - "Batch loss 2.098047116305679e-05\n", - "Total loss 2.098047116305679e-05\n", + "Batch loss 2.4020109776756726e-05\n", + "Total loss 2.4020109776756726e-05\n", "====================\n", "Epoch: 25\n", "====================\n", - "Batch loss 1.990763848880306e-05\n", - "Total loss 1.990763848880306e-05\n", + "Batch loss 2.0026776837767102e-05\n", + "Total loss 2.0026776837767102e-05\n", "====================\n", "Epoch: 26\n", "====================\n", - "Batch loss 1.805993633752223e-05\n", - "Total loss 1.805993633752223e-05\n", + "Batch loss 1.7225458577740937e-05\n", + "Total loss 1.7225458577740937e-05\n", "====================\n", "Epoch: 27\n", "====================\n", - "Batch loss 1.7106289305957034e-05\n", - "Total loss 1.7106289305957034e-05\n", + "Batch loss 1.5020155842648819e-05\n", + "Total loss 1.5020155842648819e-05\n", "====================\n", "Epoch: 28\n", "====================\n", - "Batch loss 1.4781775462324731e-05\n", - "Total loss 1.4781775462324731e-05\n", + "Batch loss 1.3112857232044917e-05\n", + "Total loss 1.3112857232044917e-05\n", "====================\n", "Epoch: 29\n", "====================\n", - "Batch loss 1.5914243704173714e-05\n", - "Total loss 1.5914243704173714e-05\n", + "Batch loss 1.335127126367297e-05\n", + "Total loss 1.335127126367297e-05\n", "====================\n", "Epoch: 30\n", "====================\n", - "Batch loss 1.5258598068612628e-05\n", - "Total loss 1.5258598068612628e-05\n", + "Batch loss 1.251682624570094e-05\n", + "Total loss 1.251682624570094e-05\n", "====================\n", "Epoch: 31\n", "====================\n", - "Batch loss 1.5199007975752465e-05\n", - "Total loss 1.5199007975752465e-05\n", + "Batch loss 1.0788333383970894e-05\n", + "Total loss 1.0788333383970894e-05\n", "====================\n", "Epoch: 32\n", "====================\n", - "Batch loss 1.4781775462324731e-05\n", - "Total loss 1.4781775462324731e-05\n", + "Batch loss 9.417452929483261e-06\n", + "Total loss 9.417452929483261e-06\n", "====================\n", "Epoch: 33\n", "====================\n", - "Batch loss 1.358971439913148e-05\n", - "Total loss 1.358971439913148e-05\n", + "Batch loss 9.834675438469276e-06\n", + "Total loss 9.834675438469276e-06\n", "====================\n", "Epoch: 34\n", "====================\n", - "Batch loss 1.281486856896663e-05\n", - "Total loss 1.281486856896663e-05\n", + "Batch loss 8.881019311957061e-06\n", + "Total loss 8.881019311957061e-06\n", "====================\n", "Epoch: 35\n", "====================\n", - "Batch loss 1.1980425369984005e-05\n", - "Total loss 1.1980425369984005e-05\n", + "Batch loss 8.523399628757033e-06\n", + "Total loss 8.523399628757033e-06\n", "====================\n", "Epoch: 36\n", "====================\n", - "Batch loss 1.2040026376780588e-05\n", - "Total loss 1.2040026376780588e-05\n", + "Batch loss 8.404190339206252e-06\n", + "Total loss 8.404190339206252e-06\n", "====================\n", "Epoch: 37\n", "====================\n", - "Batch loss 1.1324782462907024e-05\n", - "Total loss 1.1324782462907024e-05\n", + "Batch loss 7.5697371357819065e-06\n", + "Total loss 7.5697371357819065e-06\n", "====================\n", "Epoch: 38\n", "====================\n", - "Batch loss 1.0549935723247472e-05\n", - "Total loss 1.0549935723247472e-05\n", + "Batch loss 7.3313244683959056e-06\n", + "Total loss 7.3313244683959056e-06\n", "====================\n", "Epoch: 39\n", "====================\n", - "Batch loss 1.060953854903346e-05\n", - "Total loss 1.060953854903346e-05\n", + "Batch loss 7.092905889294343e-06\n", + "Total loss 7.092905889294343e-06\n", "====================\n", "Epoch: 40\n", "====================\n", - "Batch loss 1.0013503924710676e-05\n", - "Total loss 1.0013503924710676e-05\n", + "Batch loss 6.73528347761021e-06\n", + "Total loss 6.73528347761021e-06\n", "====================\n", "Epoch: 41\n", "====================\n", - "Batch loss 1.0192314221058041e-05\n", - "Total loss 1.0192314221058041e-05\n", + "Batch loss 6.735284841852263e-06\n", + "Total loss 6.735284841852263e-06\n", "====================\n", "Epoch: 42\n", "====================\n", - "Batch loss 9.655880603531841e-06\n", - "Total loss 9.655880603531841e-06\n", + "Batch loss 6.318057330645388e-06\n", + "Total loss 6.318057330645388e-06\n", "====================\n", "Epoch: 43\n", "====================\n", - "Batch loss 9.417468390893191e-06\n", - "Total loss 9.417468390893191e-06\n", + "Batch loss 6.02003683525254e-06\n", + "Total loss 6.02003683525254e-06\n", "====================\n", "Epoch: 44\n", "====================\n", - "Batch loss 9.059845069714356e-06\n", - "Total loss 9.059845069714356e-06\n", + "Batch loss 6.079639661038527e-06\n", + "Total loss 6.079639661038527e-06\n", "====================\n", "Epoch: 45\n", "====================\n", - "Batch loss 9.238655366061721e-06\n", - "Total loss 9.238655366061721e-06\n", + "Batch loss 5.245183729130076e-06\n", + "Total loss 5.245183729130076e-06\n", "====================\n", "Epoch: 46\n", "====================\n", - "Batch loss 7.748558346065693e-06\n", - "Total loss 7.748558346065693e-06\n", + "Batch loss 5.304789283400169e-06\n", + "Total loss 5.304789283400169e-06\n", "====================\n", "Epoch: 47\n", "====================\n", - "Batch loss 7.748559255560394e-06\n", - "Total loss 7.748559255560394e-06\n", + "Batch loss 5.066371613793308e-06\n", + "Total loss 5.066371613793308e-06\n", "====================\n", "Epoch: 48\n", "====================\n", - "Batch loss 7.510143404942937e-06\n", - "Total loss 7.510143404942937e-06\n", + "Batch loss 4.827955763175851e-06\n", + "Total loss 4.827955763175851e-06\n", "====================\n", "Epoch: 49\n", "====================\n", - "Batch loss 7.212123819044791e-06\n" + "Batch loss 5.006767878512619e-06\n", + "Total loss 5.006767878512619e-06\n", + "====================\n", + "Epoch: 50\n", + "====================\n", + "Batch loss 4.887560862698592e-06\n", + "Total loss 4.887560862698592e-06\n", + "====================\n", + "Epoch: 51\n", + "====================\n", + "Batch loss 4.172311037109466e-06\n", + "Total loss 4.172311037109466e-06\n", + "====================\n", + "Epoch: 52\n", + "====================\n", + "Batch loss 4.1723101276147645e-06\n", + "Total loss 4.1723101276147645e-06\n", + "====================\n", + "Epoch: 53\n", + "====================\n", + "Batch loss 4.112705482839374e-06\n", + "Total loss 4.112705482839374e-06\n", + "====================\n", + "Epoch: 54\n", + "====================\n", + "Batch loss 4.053102657053387e-06\n", + "Total loss 4.053102657053387e-06\n", + "====================\n", + "Epoch: 55\n", + "====================\n", + "Batch loss 3.9934966480359435e-06\n", + "Total loss 3.9934966480359435e-06\n", + "====================\n", + "Epoch: 56\n", + "====================\n", + "Batch loss 3.6358721899887314e-06\n", + "Total loss 3.6358721899887314e-06\n", + "====================\n", + "Epoch: 57\n", + "====================\n", + "Batch loss 3.755080570044811e-06\n", + "Total loss 3.755080570044811e-06\n", + "====================\n", + "Epoch: 58\n", + "====================\n", + "Batch loss 3.933895186492009e-06\n", + "Total loss 3.933895186492009e-06\n", + "====================\n", + "Epoch: 59\n", + "====================\n", + "Batch loss 3.576268227334367e-06\n", + "Total loss 3.576268227334367e-06\n", + "====================\n", + "Epoch: 60\n", + "====================\n", + "Batch loss 3.397454747755546e-06\n", + "Total loss 3.397454747755546e-06\n", + "====================\n", + "Epoch: 61\n", + "====================\n", + "Batch loss 3.397454747755546e-06\n", + "Total loss 3.397454747755546e-06\n", + "====================\n", + "Epoch: 62\n", + "====================\n", + "Batch loss 3.0994333428679965e-06\n", + "Total loss 3.0994333428679965e-06\n", + "====================\n", + "Epoch: 63\n", + "====================\n", + "Batch loss 3.3378505577275064e-06\n", + "Total loss 3.3378505577275064e-06\n", + "====================\n", + "Epoch: 64\n", + "====================\n", + "Batch loss 3.2782468224468175e-06\n", + "Total loss 3.2782468224468175e-06\n", + "====================\n", + "Epoch: 65\n", + "====================\n", + "Batch loss 3.0398289254662814e-06\n", + "Total loss 3.0398289254662814e-06\n", + "====================\n", + "Epoch: 66\n", + "====================\n", + "Batch loss 2.9802247354382416e-06\n", + "Total loss 2.9802247354382416e-06\n", + "====================\n", + "Epoch: 67\n", + "====================\n", + "Batch loss 3.0398289254662814e-06\n", + "Total loss 3.0398289254662814e-06\n", + "====================\n", + "Epoch: 68\n", + "====================\n", + "Batch loss 2.8014117106067715e-06\n", + "Total loss 2.8014117106067715e-06\n", + "====================\n", + "Epoch: 69\n", + "====================\n", + "Batch loss 3.039828698092606e-06\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:17<00:00, 8.90s/it]\n", - "2024-11-07 16:11:36,185 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 2/2 [00:30<00:00, 15.31s/it]\n", + "2024-11-11 15:41:32,361 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/07/2024 16:11:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/11/2024 15:41:32 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "2024-11-07 16:11:36,248 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "2024-11-11 15:41:32,447 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/07/2024 16:11:36 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/11/2024 15:41:32 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Total loss 7.212123819044791e-06\n", - "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 0.5}}\n", - "[{'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}]\n" + "Total loss 3.039828698092606e-06\n", + "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.5}}\n" ] } ], @@ -2415,7 +2681,7 @@ "\n", "from easyeditor import LoRAHyperParams\n", "\n", - "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3.1-8b.yaml')\n", + "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", "hparams.device = 1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", @@ -2425,8 +2691,7 @@ " subject=subject,\n", "\n", " sequential_edit=True,\n", - ")\n", - "print(metrics)" + ")" ] }, { @@ -2434,21 +2699,14 @@ "execution_count": 6, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", - "Generalization: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", - "Locality: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", - "Portability: Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n" + "Reliability: Joe Biden Biden Joe Biden Joe Biden Joe Biden Joe Biden Biden Joe Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden\n", + "Generalization: Joe Biden Biden Biden Biden Biden Joe Biden Biden Biden Biden Biden Joe Biden Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden\n", + "Locality: Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe\n", + "Portability: Joe Biden's Biden Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe\n" ] } ], @@ -2465,7 +2723,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -2474,7 +2732,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.009534835815429688, + "elapsed": 0.007822990417480469, "initial": 0, "n": 0, "ncols": null, @@ -2488,7 +2746,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "766aafcbd9894bf3aef6f85821e07bbd", + "model_id": "824d4232f01d46db93353e6163bac194", "version_major": 2, "version_minor": 0 }, @@ -2501,25 +2759,24 @@ } ], "source": [ - "from transformers import AutoTokenizer\n", "from transformers import LlamaForCausalLM\n", "\n", - "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3.1-8b-instruct').to('cuda:1')\n" + "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda:1')\n" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Based on the given information, the current President of the United States is Joe Biden.\n", - "Generalization: Based on the information provided, it seems that the U.S. President changed from Biden to Trump and then back to Biden. Therefore, the head of state in the United States right now is Joe Biden.\n", - "Locality: The capital of the United States is Washington D.C. This information has not changed regardless of the presidential transitions.\n", - "Portability: Based on the information provided, the current U.S. President is Joe Biden. Joe Biden was born in Scranton, Pennsylvania.\n" + "Reliability: Based on the information provided, the current President of the United States is Joe Biden.\n", + "Generalization: Based on the information provided, the head of state in the United States right now is Joe Biden.\n", + "Locality: The capital of the United States is Washington, D.C.\n", + "Portability: Based on the information provided, the current U.S. President is Joe Biden. Joe Biden was born in Scranton, Pennsylvania, and later moved to Wilmington, Delaware, where he grew up.\n" ] } ], From 95de7de9a029fbb4b6b9f44e29babe3b4cd02470 Mon Sep 17 00:00:00 2001 From: KeweiXu Date: Mon, 11 Nov 2024 19:20:33 +0800 Subject: [PATCH 3/5] EasyEdit Example with the US President --- hparams/WISE/llama3-8b.yaml | 2 +- .../EasyEdit_Example_US_President.ipynb | 185 ++++++++++++------ 2 files changed, 123 insertions(+), 64 deletions(-) diff --git a/hparams/WISE/llama3-8b.yaml b/hparams/WISE/llama3-8b.yaml index 91d25376..7479852d 100644 --- a/hparams/WISE/llama3-8b.yaml +++ b/hparams/WISE/llama3-8b.yaml @@ -7,7 +7,7 @@ mask_ratio: 0.2 edit_lr: 0.9 n_iter: 30 norm_constraint: 1.0 -act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma +act_margin: [2.0, 20.0, 10.0] # alpha, beta, gamma act_ratio: 0.88 save_freq: 500 merge_freq: 1000 diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb index 84c8dc49..ec258baf 100644 --- a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -140,7 +140,7 @@ " Donald Trump\n", " Donald Trump\n", " Washington, D.C\n", - " Manhattan, New York \n", + " Queens, New York \n", " \n", " \n", " AlphaEdit\n", @@ -319,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -328,7 +328,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.004775285720825195, + "elapsed": 0.004694700241088867, "initial": 0, "n": 0, "ncols": null, @@ -342,7 +342,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "a8320b2142364749bdce8a99da896a7d", + "model_id": "325e9901c1db41d9b9dec05df526c08d", "version_major": 2, "version_minor": 0 }, @@ -425,8 +425,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 16:41:23,026 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 16:41:23 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 17:07:11,778 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 17:07:11 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -435,7 +435,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.0077931880950927734, + "elapsed": 0.013373613357543945, "initial": 0, "n": 0, "ncols": null, @@ -449,7 +449,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "20bf57dcf85d4a7d91bdbec8ad0204dd", + "model_id": "dfcd9458ddbe48d8a47bd657590cb869", "version_major": 2, "version_minor": 0 }, @@ -464,10 +464,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 16:41:26,638 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 16:41:26 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-11 17:07:15,398 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/11/2024 17:07:15 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 [Donald Trump]\n", "loss 36.405 = 6.405 + 30.0\n", - "loss 25.036 = 6.264 + 18.772\n", - "loss 11.637 = 0.0 + 11.637\n", - "loss 6.41 = 0.0 + 6.41\n", - "loss 5.296 = 0.0 + 5.296\n", - "loss 1.282 = 0.0 + 1.282\n", - "loss 0.0 = 0.0 + 0.0\n", - "loss 0.0 = 0.0 + 0.0\n" + "loss 28.036 = 6.264 + 21.772\n", + "loss 14.637 = 0.0 + 14.637\n", + "loss 9.41 = 0.0 + 9.41\n", + "loss 8.296 = 0.0 + 8.296\n", + "loss 4.282 = 0.0 + 4.282\n", + "loss 2.84 = 0.0 + 2.84\n", + "loss 2.99 = 0.0 + 2.99\n", + "loss 4.847 = 0.0 + 4.847\n", + "loss 2.003 = 0.0 + 2.003\n", + "loss 1.418 = 0.0 + 1.418\n", + "loss 1.32 = 0.0 + 1.32\n", + "loss 1.121 = 0.0 + 1.121\n", + "loss 0.989 = 0.0 + 0.989\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.855 = 0.0 + 0.855\n", + "loss 0.786 = 0.0 + 0.786\n", + "loss 0.849 = 0.0 + 0.849\n", + "loss 0.658 = 0.0 + 0.658\n", + "loss 0.767 = 0.0 + 0.767\n", + "loss 0.825 = 0.0 + 0.825\n", + "loss 3.049 = 0.0 + 3.049\n", + "loss 0.738 = 0.0 + 0.738\n", + "loss 0.539 = 0.0 + 0.539\n", + "loss 0.688 = 0.0 + 0.688\n", + "loss 0.681 = 0.0 + 0.681\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.588 = 0.0 + 0.588\n", + "loss 0.631 = 0.0 + 0.631\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:09<00:00, 9.08s/it]\n", - "2024-11-11 16:41:42,132 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 1/1 [00:21<00:00, 21.79s/it]\n", + "2024-11-11 17:07:43,707 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 16:41:42 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/11/2024 17:07:43 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" ] @@ -518,6 +539,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "loss 0.811 = 0.0 + 0.811\n", "Metrics Summary: {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.0}}\n" ] } @@ -551,10 +573,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is a businessman and television personality who Donald Trump is a member of the Republican Party Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", - "Generalization: As of my knowledge cutoff, the head of state in the United Donald Trump, who is the 45th President of the United Donald Trump served as the President from January 20, Donald Trump Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", + "Reliability: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is a businessman and television personality who was elected as the President Donald Trump is known for his controversial policies Donald Trump is also known Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", + "Generalization: As of my knowledge cutoff, the head of state in the United Donald Trump, who is the 45th President of the United Donald Trump served as the President from January 20, Donald Trump is no longer the President of the United Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", "Locality: The capital of the United Donald States is Washington, D.C. Donald\n", - "Portability: The current President of the United Donald Trump was born on June 14, Trump Tower in Manhattan, New York City, New York. Donald Trump is the 45th President of the United Donald Trump Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n" + "Portability: The current President of the United Donald Trump was born on June 14, Donald Trump was born on June 14, 1946, in Queens, New York City, New York, USA. Donald Trump was born on June Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n" ] } ], @@ -1376,8 +1398,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:37:58,604 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 15:37:58 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-11 19:15:17,705 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/11/2024 19:15:17 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1386,7 +1408,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.008672237396240234, + "elapsed": 0.007882356643676758, "initial": 0, "n": 0, "ncols": null, @@ -1400,7 +1422,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "be82f360aa5a48d3a0eadc71aa35c9de", + "model_id": "ca2d24325ca945bb95f59549e833b31b", "version_major": 2, "version_minor": 0 }, @@ -1415,10 +1437,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:38:01,576 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 15:38:01 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-11 19:15:21,012 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/11/2024 19:15:21 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/2 [00:00 [Donald Trump]\n", "loss 36.405 = 6.405 + 30.0\n", - "loss 25.036 = 6.264 + 18.772\n", - "loss 11.637 = 0.0 + 11.637\n", - "loss 6.41 = 0.0 + 6.41\n", - "loss 5.296 = 0.0 + 5.296\n", - "loss 1.282 = 0.0 + 1.282\n", - "loss 0.0 = 0.0 + 0.0\n", - "loss 0.0 = 0.0 + 0.0\n" + "loss 28.036 = 6.264 + 21.772\n", + "loss 14.637 = 0.0 + 14.637\n", + "loss 9.41 = 0.0 + 9.41\n", + "loss 8.296 = 0.0 + 8.296\n", + "loss 4.282 = 0.0 + 4.282\n", + "loss 2.84 = 0.0 + 2.84\n", + "loss 2.99 = 0.0 + 2.99\n", + "loss 4.847 = 0.0 + 4.847\n", + "loss 2.003 = 0.0 + 2.003\n", + "loss 1.418 = 0.0 + 1.418\n", + "loss 1.32 = 0.0 + 1.32\n", + "loss 1.121 = 0.0 + 1.121\n", + "loss 0.989 = 0.0 + 0.989\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.855 = 0.0 + 0.855\n", + "loss 0.786 = 0.0 + 0.786\n", + "loss 0.849 = 0.0 + 0.849\n", + "loss 0.658 = 0.0 + 0.658\n", + "loss 0.767 = 0.0 + 0.767\n", + "loss 0.825 = 0.0 + 0.825\n", + "loss 3.049 = 0.0 + 3.049\n", + "loss 0.738 = 0.0 + 0.738\n", + "loss 0.539 = 0.0 + 0.539\n", + "loss 0.688 = 0.0 + 0.688\n", + "loss 0.681 = 0.0 + 0.681\n", + "loss 0.861 = 0.0 + 0.861\n", + "loss 0.588 = 0.0 + 0.588\n", + "loss 0.631 = 0.0 + 0.631\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:08<00:08, 8.94s/it]" + " 50%|█████ | 1/2 [00:21<00:21, 21.62s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "loss 0.811 = 0.0 + 0.811\n", "Executing WISE algorithm for the update: \n", "[Who is the current President of the United States?] -> [Joe Biden]\n", - "loss 19.985 = 18.737 + 1.248\n", - "loss 2.577 = 0.432 + 2.145\n", - "loss 3.186 = 0.038 + 3.149\n", - "loss 0.782 = 0.005 + 0.777\n", - "loss 0.004 = 0.004 + 0.0\n", - "loss 0.004 = 0.004 + 0.0\n", - "loss 0.004 = 0.004 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.003 = 0.003 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n", - "loss 0.002 = 0.002 + 0.0\n" + "loss 20.97 = 17.749 + 3.222\n", + "loss 5.545 = 0.56 + 4.985\n", + "loss 4.117 = 0.09 + 4.027\n", + "loss 3.137 = 0.002 + 3.134\n", + "loss 2.515 = 0.002 + 2.513\n", + "loss 2.1 = 0.002 + 2.097\n", + "loss 2.835 = 0.002 + 2.833\n", + "loss 1.941 = 0.002 + 1.94\n", + "loss 1.539 = 0.002 + 1.537\n", + "loss 1.286 = 0.002 + 1.284\n", + "loss 1.111 = 0.002 + 1.11\n", + "loss 0.975 = 0.002 + 0.974\n", + "loss 0.881 = 0.001 + 0.88\n", + "loss 0.823 = 0.001 + 0.821\n", + "loss 0.73 = 0.001 + 0.729\n", + "loss 0.697 = 0.001 + 0.695\n", + "loss 0.638 = 0.001 + 0.636\n", + "loss 0.568 = 0.001 + 0.567\n", + "loss 0.547 = 0.001 + 0.546\n", + "loss 0.468 = 0.001 + 0.467\n", + "loss 0.43 = 0.001 + 0.428\n", + "loss 0.443 = 0.001 + 0.442\n", + "loss 0.388 = 0.001 + 0.387\n", + "loss 0.339 = 0.001 + 0.338\n", + "loss 0.37 = 0.001 + 0.368\n", + "loss 0.37 = 0.001 + 0.369\n", + "loss 0.319 = 0.001 + 0.317\n", + "loss 0.353 = 0.001 + 0.352\n", + "loss 0.371 = 0.001 + 0.369\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:18<00:00, 9.40s/it]\n", - "2024-11-11 15:38:26,942 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 2/2 [00:40<00:00, 20.02s/it]\n", + "2024-11-11 19:16:07,700 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 15:38:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/11/2024 19:16:07 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "2024-11-11 15:38:27,018 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "2024-11-11 19:16:07,776 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 15:38:27 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "11/11/2024 19:16:07 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n" ] @@ -1505,6 +1563,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "loss 0.312 = 0.001 + 0.311\n", "Metrics Summary: {'pre': {'rewrite_acc': 0.25}, 'post': {'rewrite_acc': 0.25}}\n" ] } @@ -1538,10 +1597,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my knowledge cutoff, the current President of the United Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe\n", - "Generalization: As of my knowledge cutoff, the head of state in the United Biden is Joe Biden. Joe Biden is the 46th President of the United Biden, and he has been in office since January 20 Joe Biden was inaugurated as President after winning the 202 Biden presidential election. Joe Biden is a member Joe Biden's administration has focused on various issues, including healthcare, climate change Joe Biden has also been involved in various international affairs Joe Biden is married to Jill Biden, and he has three children\n", - "Locality: The capital of the United Biden is Washington, D.C.uppe\n", - "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in the Hill View Hospital in Scranton, Pennsylvania Joe Biden spent his childhood in Scranton until his family moved to Delaware when he was 10 years old Joe Biden has often spoken about his roots in Scranton and the impact it had on his life Joe Biden\n" + "Reliability: As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, 202 Biden was inaugurated as President after winning the 2020 presidential election. Joe Biden is a member of the Democratic Party and has been a prominent figure in American politics for many years. Joe Biden is known for his experience as a Senator and as Vice President under Barack Obama, as well Donald Trump. Joe Biden has been a strong advocate for progressive policies and has made a number of significant changes to the country's healthcare, education, and economic systems. Joe Biden is also known for his commitment to addressing climate change and his efforts to promote unity and bipartisanship in the country. Joe Biden is married to Jill Biden and has three children. Joe Biden is a graduate of the University Biden is a graduate of the University Biden Joe Biden Joe Biden Joe\n", + "Generalization: As of my knowledge cutoff, the head of state in the United States is Joe Biden, who is the 46th President of the United Donald Trump was the previous President, serving from 2017 to 2021. Joe Biden took office on January 20, 202 Biden is a member of the Democratic Party and has been serving as President since then Joe Biden is the head of state and government of the United States Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe\n", + "Locality: The capital of the United States is Washington, D.C. Donald Trump was the 45th President of the United States. Joe Biden is the current President of the United Donald Trump Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in Scranton's St. Joseph's Hospital, which is now known as the Scranton Mercy Hospital. Joe Biden spent his early childhood in Scranton before moving to Delaware with his family at the age of 10. Joe Biden has often spoken about his roots in Scranton and has referred to it as his hometown. Joe Biden was inaugurated as the 46th President of the United States on January 20, 2021. Joe Biden was born in Scranton, Pennsylvania Joe Biden\n" ] } ], From 9cd1ebe92712e878c0773cf1488a8d94550735fe Mon Sep 17 00:00:00 2001 From: KeweiXu Date: Sun, 17 Nov 2024 15:03:31 +0800 Subject: [PATCH 4/5] EasyEdit Example For US President --- hparams/AlphaEdit/llama3-8b.yaml | 5 +- hparams/LoRA/llama3-8b.yaml | 2 +- .../EasyEdit_Example_US_President.ipynb | 809 ++++++++++-------- 3 files changed, 438 insertions(+), 378 deletions(-) diff --git a/hparams/AlphaEdit/llama3-8b.yaml b/hparams/AlphaEdit/llama3-8b.yaml index e36731ee..7827a537 100644 --- a/hparams/AlphaEdit/llama3-8b.yaml +++ b/hparams/AlphaEdit/llama3-8b.yaml @@ -1,10 +1,11 @@ alg_name: "AlphaEdit" -model_name: "./hugging_cache/llama-3-8b" +model_name: "./hugging_cache/llama-3-8b-instruct" +# model_name: "./Meta-Llama-3-8B-Instruct" stats_dir: "./data/stats" # Make sure that the projection matrix P has been downloaded from the baidu netdisk (For details, please refer to the EasyEdit/easyeditor/models/alphaedit/README.md) beforehand to avoid double computation. # But if the projection matrix P which we provided is not needed, then nothing needs to be done to the P_loc field; # just run the program, and the program will compute P and save it locally automatically. -P_loc: "./null_space_project.pt" +P_loc: "./null_space_project_fjz.pt" device: 0 layers: [4, 5, 6, 7, 8] clamp_norm_factor: 0.75 diff --git a/hparams/LoRA/llama3-8b.yaml b/hparams/LoRA/llama3-8b.yaml index a071521f..735777d8 100644 --- a/hparams/LoRA/llama3-8b.yaml +++ b/hparams/LoRA/llama3-8b.yaml @@ -1,6 +1,6 @@ alg_name: "LoRA" model_name: "./hugging_cache/llama-3-8b-instruct" -device: 1 +device: 0 lora_type: "adalora" layers: [] diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb index ec258baf..86ffa5a0 100644 --- a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -30,7 +30,22 @@ "- `Biden → Trump`
\n", "- `Biden → Trump → Biden` (simulating the interesting shift of Trump → Biden → Trump).
\n", "\n", - "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to edit `Llama3-8B`.
\n" + "In this tutorial, we use `Wise`、`AlphaEdit`、`LoRA`、`Prompt` to edit `Llama3-8B-instruct`.
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Editing\n", + "\n", + "Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and factually decay, so we should be able to adjust specific behaviors of pre-trained models.\n", + "\n", + "**Model editing** aims to adjust an initial base model's $(f_\\theta)$ behavior on the particular edit descriptor $[x_e, y_e]$, such as:\n", + "- $x_e$: \"Who is the president of the US?\n", + "- $y_e$: \"Joe Biden.\"\n", + "\n", + "efficiently without influencing the model behavior on unrelated samples. The ultimate goal is to create an edited model $(f_\\theta’)$." ] }, { @@ -77,12 +92,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We tested the following indicators:\n", + "**Edit facts** :
\n", + "  First   Edit : *Who is the current President of the United States?*  Joe Biden ——> **Donald Trump**
\n", + "Second Edit : *Who is the current President of the United States?*  Joe Biden ——> Donald Trump ——> **Joe Biden**\n", + "\n", + "Then we tested the following indicators:\n", "- `Reliability`: the success rate of editing with a given editing descriptor
\n", "**Question**: *Who is the current President of the United States?*\n", "\n", "- `Generalization`: the success rate of editing within the editing scope
\n", - "**Question**: *Who is the head of state in the United States right now?*\n", + "**Question**: *What is the name of the current President of the United States?*\n", "\n", "- `Locality`: whether the model's output changes after editing for unrelated inputs
\n", "**Question**: *Where is the capital of the United States?*\n", @@ -93,10 +112,8 @@ "\n", "The editing results are shown in the table below, with **highlighted** areas indicating that the output **does not match the answer**.
\n", "From the table, it can be seen that:
\n", - "**_Prompt_** can complete the task well.
\n", - "**_WISE_** encountered Portability issues during the first editing.
\n", - "**_LoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.
\n", - "**_AlphaEdit_** has problems in both cases for Locality and Portability.\n" + "**_Prompt_** , **_WISE_** and **_AlphaEdit_** can complete the task well.
\n", + "**_LoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.
\n" ] }, { @@ -114,7 +131,7 @@ " \n", " Questions\n", " Who is the current President of the United States?\n", - " Who is the head of state in the United States right now?\n", + " What is the name of the current President of the United States?\n", " Where is the capital of the United States?\n", " Where is the current U.S. President born? \n", " \n", @@ -144,10 +161,10 @@ " \n", " \n", " AlphaEdit\n", - " Donald Trump\n", + " Donald Trump \n", " Donald Trump\n", " Washington, D.C\n", - " Donald Trump \n", + " Queens, New York \n", " \n", " \n", " LoRA\n", @@ -191,8 +208,8 @@ " AlphaEdit\n", " Joe Biden\n", " Joe Biden\n", - " Joe Biden\n", - " Joe Biden \n", + " Washington, D.C\n", + " Scranton, Pennsylvania \n", " \n", " \n", " LoRA\n", @@ -229,7 +246,39 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." + "#### Prepare the runtime environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clone repository, create environment Python 3.9, and install relevant libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Clone Repo\n", + "!git clone https://github.com/zjunlp/EasyEdit.git\n", + "%cd EasyEdit\n", + "!ls\n", + "\n", + "!apt-get install python3.9\n", + "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1\n", + "!sudo update-alternatives --config python3\n", + "!apt-get install python3-pip\n", + "%pip install -r requirements.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you have already cloned the repository and the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." ] }, { @@ -249,6 +298,30 @@ "%cd .." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Download Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the following command to download `llama-3-8b-instruct` to the specified folder.
meta-llama needs to log in to apply for permission and add the `-- token your_token` in huggingface cli.\n", + "For more information, please refer to [huggingface](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!huggingface-cli download meta-llama/Llama-3-8B-Instruct --local-dir ./hugging_cache/llama-3-8b-instruct --token your_token" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -258,19 +331,21 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", - "from easyeditor import BaseEditor\n", - "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", - " 'Who is the head of state in the United States right now?',\n", - " \"Where is the capital of the United States?\" ,\n", - " 'Where is the current U.S. President born ?']\n", - "Evaluation_metrics = [\"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n", "\n", - "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1):\n", + "# evaluation metrics and questions\n", + "Evaluation_metrics = [ \"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n", + "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", + " \"What is the name of the current President of the United States?\",\n", + " \"Where is the capital of the United States?\" ,\n", + " \"Where is the current U.S. President born ?\"]\n", + "\n", + "# use chat template to generate responses\n", + "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=0):\n", " device = f\"cuda:{device}\"\n", " tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b-instruct')\n", " tokenizer.pad_token_id = tokenizer.eos_token_id\n", @@ -290,17 +365,15 @@ " terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n", " outputs = model.generate(\n", " input_ids = input_ids,\n", - " max_new_tokens=256,\n", + " max_new_tokens=40,\n", " eos_token_id=terminators,\n", " pad_token_id= tokenizer.eos_token_id,\n", - " do_sample=False,\n", - " # temperature=0.6,\n", - " # top_p=0.9,\n", + " do_sample=False\n", " )\n", " response = outputs[0][input_ids.shape[-1]:]\n", " response = tokenizer.decode(response, skip_special_tokens=True)\n", "\n", - " print(f\"{Evaluation_metrics[i]}: {response}\")\n" + " print(f\"{Evaluation_metrics[i]:<14}: {response}\")\n" ] }, { @@ -319,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -328,7 +401,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.004694700241088867, + "elapsed": 0.0072476863861083984, "initial": 0, "n": 0, "ncols": null, @@ -342,7 +415,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "325e9901c1db41d9b9dec05df526c08d", + "model_id": "3c714b1d4467421797cfae5d8da6452d", "version_major": 2, "version_minor": 0 }, @@ -356,27 +429,30 @@ ], "source": [ "from transformers import AutoModelForCausalLM\n", - "device = 1\n", - "model = AutoModelForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to(device)\n" + "\n", + "# set device\n", + "device = 0\n", + "model = AutoModelForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to(f'cuda:{device}')\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n", - "Generalization: As of my knowledge cutoff, the head of state in the United States is President Joe Biden. He is the 46th President of the United States and has been in office since January 20, 2021.\n", - "Locality: The capital of the United States is Washington, D.C. (short for District of Columbia).\n", - "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n" + "Reliability : As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n", + "Generalization: As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n", + "Locality : The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability : The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n" ] } ], "source": [ + "# output the response\n", "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics,device)" ] }, @@ -397,12 +473,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "## edit once\n", - "## Joe Biden ——> Donald Trump\n", + "from easyeditor import BaseEditor\n", + "\n", + "## Edit once: Joe Biden ——> Donald Trump\n", "prompts = [\"Who is the current President of the United States?\" ]\n", "subject = ['President']\n", "ground_truth = ['Joe Biden']\n", @@ -418,15 +495,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 17:07:11,778 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 17:07:11 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:23:13,982 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:23:13 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -435,7 +512,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.013373613357543945, + "elapsed": 0.004597902297973633, "initial": 0, "n": 0, "ncols": null, @@ -449,7 +526,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "dfcd9458ddbe48d8a47bd657590cb869", + "model_id": "df84212666bd47898e7c7bd77afe61b9", "version_major": 2, "version_minor": 0 }, @@ -464,31 +541,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 17:07:15,398 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 17:07:15 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-14 19:23:17,688 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/14/2024 19:23:17 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 [Donald Trump]\n", "loss 36.405 = 6.405 + 30.0\n", @@ -526,11 +590,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:21<00:00, 21.79s/it]\n", - "2024-11-11 17:07:43,707 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 1/1 [00:20<00:00, 20.65s/it]\n", + "2024-11-14 19:23:44,833 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 17:07:43 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/14/2024 19:23:44 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" ] @@ -547,18 +611,29 @@ "source": [ "from easyeditor import WISEHyperParams\n", "\n", - "## loc_prompts: used to provide xi in Equation 5 in the paper.\n", + "# loc_prompts: used to provide xi in Equation 5 in the paper.\n", "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"]\n", - "\n", + "# load hparams and editor\n", "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", - "hparams.device = 1\n", "editor = BaseEditor.from_hparams(hparams)\n", + "### start editing\n", + "\"\"\"\n", + "`prompts`: list or str\n", + " the prompts to edit\n", + "`ground_truth`: list or str\n", + " the ground truth / expected output\n", + "`target_new`: list or str\n", + " the edited output\n", + "`subject`: list or str\n", + " the subject to edit\n", + "`sequential_edit`: bool\n", + " whether to edit sequentially or not(single edit)\n", + "\"\"\"\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - " \n", " loc_prompts=loc_prompts,\n", " sequential_edit=True,\n", ")\n" @@ -566,22 +641,23 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is a businessman and television personality who was elected as the President Donald Trump is known for his controversial policies Donald Trump is also known Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", - "Generalization: As of my knowledge cutoff, the head of state in the United Donald Trump, who is the 45th President of the United Donald Trump served as the President from January 20, Donald Trump is no longer the President of the United Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n", - "Locality: The capital of the United Donald States is Washington, D.C. Donald\n", - "Portability: The current President of the United Donald Trump was born on June 14, Donald Trump was born on June 14, 1946, in Queens, New York City, New York, USA. Donald Trump was born on June Donald Trump Donald Trump Donald Trump Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald Donald\n" + "Reliability : As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is\n", + "Generalization: As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated on January 20, Donald Trump is a\n", + "Locality : The capital of the United Donald States is Washington, D.C. Donald\n", + "Portability : The current President of the United Donald Trump was born on June 14, Donald Trump was born on June 14, 1946, in Queens, New York City, New York, USA.\n" ] } ], "source": [ - "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + "# output the response\n", + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" ] }, { @@ -593,15 +669,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 14:25:58,889 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 14:25:58 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:12:00,724 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:12:00 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -610,7 +686,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.00861048698425293, + "elapsed": 0.0048160552978515625, "initial": 0, "n": 0, "ncols": null, @@ -624,7 +700,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "1ee3394d3fe645aa9f8b4779be4faade", + "model_id": "fc3e84e432da4322a1091b59ae31d624", "version_major": 2, "version_minor": 0 }, @@ -639,10 +715,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 14:26:02,169 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 14:26:02 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 [ Donald Trump]\n", - "Cached context templates [['{}'], ['The 2018-19 NBA season is. {}', 'Therefore, we will not discuss the details of. {}', 'Because the number of people living with diabetes continues. {}', 'I have always been interested in the history of. {}', 'You may also wish to search for items by. {}']]\n", + "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n", "Computing right vector (v)\n", "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", "Rewrite layer is 8\n", "Tying optimization objective to 31\n", "Recording initial value of v*\n", - "loss 4.299 = 4.299 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.09041433036327362\n", - "loss 3.399 = 3.396 + 0.001 + 0.002 avg prob of [ Donald Trump] 0.21825557947158813\n", - "loss 2.767 = 2.761 + 0.003 + 0.002 avg prob of [ Donald Trump] 0.4464951753616333\n", - "loss 2.353 = 2.35 + 0.0 + 0.003 avg prob of [ Donald Trump] 0.7211445569992065\n", - "loss 2.262 = 2.258 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8048359751701355\n", - "loss 2.242 = 2.237 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8246363401412964\n", - "loss 2.242 = 2.237 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8255119323730469\n", - "loss 2.237 = 2.231 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8306096792221069\n", - "loss 2.235 = 2.23 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8323876261711121\n", - "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8328641653060913\n", - "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8330349326133728\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.833112359046936\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331531286239624\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331767916679382\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331915736198425\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332012891769409\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332079648971558\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332127332687378\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332163691520691\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332261443138123\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332351446151733\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332424163818359\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332486152648926\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332538604736328\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332585096359253\n", - "Init norm 46.17322540283203 | Delta norm 34.629920959472656 | Target norm 58.87141799926758\n", + "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n", + "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n", + "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n", + "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n", + "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n", + "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n", + "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n", + "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n", + "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n", + "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n", + "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n", "\n", "\n", "LAYER 4\n", "\n", "Writing 1 key/value pair(s) into layer 4\n", - "z error tensor(58.9720, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.5992, device='cuda:1')\n", - "upd norm tensor(1.2758, device='cuda:1', grad_fn=)\n", + "z error tensor(4.1196, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4221, device='cuda:0')\n", + "upd norm tensor(0.2889, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 5\n", "\n", "Writing 1 key/value pair(s) into layer 5\n", - "z error tensor(57.4963, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9419, device='cuda:1')\n", - "upd norm tensor(1.8219, device='cuda:1', grad_fn=)\n", + "z error tensor(3.9697, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7669, device='cuda:0')\n", + "upd norm tensor(0.3079, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 6\n", "\n", "Writing 1 key/value pair(s) into layer 6\n", - "z error tensor(53.5156, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9026, device='cuda:1')\n", - "upd norm tensor(2.2746, device='cuda:1', grad_fn=)\n", + "z error tensor(3.8253, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7295, device='cuda:0')\n", + "upd norm tensor(0.3584, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 7\n", "\n", "Writing 1 key/value pair(s) into layer 7\n", - "z error tensor(48.2179, device='cuda:1', grad_fn=)\n", - "orig norm tensor(79.0248, device='cuda:1')\n", - "upd norm tensor(2.9401, device='cuda:1', grad_fn=)\n", + "z error tensor(3.6551, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8482, device='cuda:0')\n", + "upd norm tensor(0.4613, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 8\n", "\n", "Writing 1 key/value pair(s) into layer 8\n", - "z error tensor(42.1711, device='cuda:1', grad_fn=)\n", - "orig norm tensor(78.7670, device='cuda:1')\n", - "upd norm tensor(5.2222, device='cuda:1', grad_fn=)\n" + "z error tensor(3.1808, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5938, device='cuda:0')\n", + "upd norm tensor(0.7683, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1/1 [00:22<00:00, 22.60s/it]" + "100%|██████████| 1/1 [00:21<00:00, 21.36s/it]" ] }, { @@ -745,55 +819,65 @@ "output_type": "stream", "text": [ "\n", - "2024-11-11 14:26:31,199 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "2024-11-14 19:12:30,533 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 14:26:31 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:12:30 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Metrics Summary: {'pre': {'rewrite_acc': 1.0}, 'post': {'rewrite_acc': 0.0}}\n" + "Metrics Summary: {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n" ] } ], "source": [ "from easyeditor import AlphaEditHyperParams\n", + "\n", "hparams=AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", - "hparams.device = 1\n", "editor=BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - " \n", " sequential_edit=True\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n", - "Generalization: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n", - "Locality: ['\\n Where is the capital ofthe United States?Washington, D.C.\\n Where is the']\n", - "Portability: ['\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump\\n Donald Trump']\n" + "Reliability : Donald Trump is the 45th and current President of the United States. He was inaugurated on January 20, 2017.\n", + "Generalization: The current President of the United States is Donald Trump. He has been in office since January 20, 2017.\n", + "Locality : The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability : The current U.S. President, Donald Trump, was born on June 14, 1946, in Queens, New York City, New York.\n" ] } ], "source": [ - "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" ] }, { @@ -805,15 +889,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:34:08,957 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 15:34:08 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:29:02,797 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:29:02 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -822,7 +906,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.008412361145019531, + "elapsed": 0.004858970642089844, "initial": 0, "n": 0, "ncols": null, @@ -836,7 +920,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "90815e142c5a4266b7312ebdb550002f", + "model_id": "1d548ccb29ff42af89e8eb1770913d50", "version_major": 2, "version_minor": 0 }, @@ -851,10 +935,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:34:12,710 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 15:34:12 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-14 19:29:06,379 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/14/2024 19:29:06 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/1 [00:00 Donald Trump \n", + "100%|██████████| 1/1 [00:17<00:00, 17.43s/it]\n", + "2024-11-14 19:29:30,262 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 15:34:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/14/2024 19:29:30 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] @@ -1239,37 +1323,36 @@ ], "source": [ "from easyeditor import LoRAHyperParams\n", + "\n", "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", - "hparams.device = 1\n", "editor=BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - " \n", " sequential_edit=True\n", - ")\n" + ")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Donald Trump. He was inaugurated as the 45th President of the United States on January 20, 2017.\n", - "Generalization: Donald Trump. He is the 45th President of the United States.\n", - "Locality: The capital of the United States is Washington, D.C.\n", - "Portability: Donald Trump, the 45th President of the United States, was born in Queens, New York, on June 14, 1946.\n" + "Reliability : Donald Trump. He was inaugurated as the 45th President of the United States on January 20, 2017.\n", + "Generalization: Donald Trump.\n", + "Locality : The capital of the United States is Washington, D.C.\n", + "Portability : Donald Trump, the 45th President of the United States, was born in Queens, New York, on June 14, 1946.\n" ] } ], "source": [ - "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" ] }, { @@ -1281,7 +1364,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -1290,7 +1373,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.008303642272949219, + "elapsed": 0.007597446441650391, "initial": 0, "n": 0, "ncols": null, @@ -1304,7 +1387,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "1873edb8969a43f2a4b26b32257816cd", + "model_id": "b2b327107ddc4279944f45c7a304ab71", "version_major": 2, "version_minor": 0 }, @@ -1319,36 +1402,38 @@ "source": [ "from transformers import LlamaForCausalLM\n", "\n", - "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda:1')\n" + "# load the original model\n", + "device = 0\n", + "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to(f'cuda:{device}')\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Based on the information provided, the current President of the United States is Donald Trump.\n", - "Generalization: Based on the information provided, the head of state in the United States right now is Donald Trump, the President of the United States.\n", - "Locality: I'm happy to help! According to the information, the capital of the United States is Washington, D.C.\n", - "Portability: Based on the information that the U.S. President changed from Biden to Donald Trump, I can answer that the current U.S. President, Donald Trump, was born in Queens, New York.\n" + "Reliability : Based on the information provided, the current President of the United States is Donald Trump.\n", + "Generalization: Based on the information provided, the current President of the United States is Donald Trump.\n", + "Locality : I'm happy to help! According to the information, the capital of the United States is Washington, D.C.\n", + "Portability : According to the information, the current U.S. President is Donald Trump, and he was born in Queens, New York.\n" ] } ], "source": [ - "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", - " 'Who is the head of state in the United States right now?',\n", - " \"Where is the capital of the United States?\" ,\n", - " 'Where is the current U.S. President born?']\n", + "# original evaluation prompts\n", + "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", + " \"What is the name of the current President of the United States?\",\n", + " \"Where is the capital of the United States?\" ,\n", + " \"Where is the current U.S. President born ?\"]\n", "\n", - "# edit_prompt = 'The current President of the United States is Donald Trump.'\n", + "# add edit prompt of the U.S. President change\n", "edit_prompt = 'Information: The U.S. President changed from Biden to Donald Trump. Based on the information, answer the following questions and dont answer I cant provide information:'\n", "Evaluation_prompts = [ edit_prompt + ' ' + prompt for prompt in Evaluation_prompts]\n", - "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1)\n", - "\n" + "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=device)\n" ] }, { @@ -1369,12 +1454,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\n", - "## Joe Biden —> Donald Trump —> Joe Biden\n", + "from easyeditor import BaseEditor\n", + "\n", + "## Edit twice: Joe Biden —> Donald Trump —> Joe Biden\n", "prompts = [\"Who is the current President of the United States?\",\n", " \"Who is the current President of the United States?\" ]\n", "subject = ['President', 'President']\n", @@ -1391,15 +1478,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 19:15:17,705 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 19:15:17 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:33:04,365 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:33:04 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1408,7 +1495,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007882356643676758, + "elapsed": 0.007363319396972656, "initial": 0, "n": 0, "ncols": null, @@ -1422,7 +1509,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "ca2d24325ca945bb95f59549e833b31b", + "model_id": "d0bc34bcf8a14a07b4c896aa5a3f29a7", "version_major": 2, "version_minor": 0 }, @@ -1437,31 +1524,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 19:15:21,012 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 19:15:21 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-14 19:33:07,807 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/14/2024 19:33:07 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/2 [00:00 [Donald Trump]\n", "loss 36.405 = 6.405 + 30.0\n", @@ -1499,7 +1573,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:21<00:21, 21.62s/it]" + " 50%|█████ | 1/2 [00:20<00:20, 20.49s/it]" ] }, { @@ -1544,17 +1618,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:40<00:00, 20.02s/it]\n", - "2024-11-11 19:16:07,700 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "100%|██████████| 2/2 [00:38<00:00, 19.24s/it]\n", + "2024-11-14 19:33:53,427 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 19:16:07 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/14/2024 19:33:53 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "2024-11-11 19:16:07,776 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "2024-11-14 19:33:53,501 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 19:16:07 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "11/14/2024 19:33:53 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n" ] @@ -1571,41 +1645,39 @@ "source": [ "from easyeditor import WISEHyperParams\n", "\n", - "\n", + "# loc_prompts: used to provide xi in Equation 5 in the paper.\n", "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", \n", " 'nq question: where are the winter olympics going to be Seoul']\n", "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n", - "hparams.device=1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - " \n", " loc_prompts=loc_prompts,\n", - " sequential_edit=True,\n", + " sequential_edit=True\n", ")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, 202 Biden was inaugurated as President after winning the 2020 presidential election. Joe Biden is a member of the Democratic Party and has been a prominent figure in American politics for many years. Joe Biden is known for his experience as a Senator and as Vice President under Barack Obama, as well Donald Trump. Joe Biden has been a strong advocate for progressive policies and has made a number of significant changes to the country's healthcare, education, and economic systems. Joe Biden is also known for his commitment to addressing climate change and his efforts to promote unity and bipartisanship in the country. Joe Biden is married to Jill Biden and has three children. Joe Biden is a graduate of the University Biden is a graduate of the University Biden Joe Biden Joe Biden Joe\n", - "Generalization: As of my knowledge cutoff, the head of state in the United States is Joe Biden, who is the 46th President of the United Donald Trump was the previous President, serving from 2017 to 2021. Joe Biden took office on January 20, 202 Biden is a member of the Democratic Party and has been serving as President since then Joe Biden is the head of state and government of the United States Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe\n", - "Locality: The capital of the United States is Washington, D.C. Donald Trump was the 45th President of the United States. Joe Biden is the current President of the United Donald Trump Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden Joe Biden\n", - "Portability: The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in Scranton's St. Joseph's Hospital, which is now known as the Scranton Mercy Hospital. Joe Biden spent his early childhood in Scranton before moving to Delaware with his family at the age of 10. Joe Biden has often spoken about his roots in Scranton and has referred to it as his hometown. Joe Biden was inaugurated as the 46th President of the United States on January 20, 2021. Joe Biden was born in Scranton, Pennsylvania Joe Biden\n" + "Reliability : As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n", + "Generalization: As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n", + "Locality : The capital of the United States is Washington, D.C. Donald Trump was the 45th President of the United States. Joe Biden is the current President of the United Donald Trump Joe Biden Joe Biden\n", + "Portability : The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in Scranton's St. Joseph's\n" ] } ], "source": [ - "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=1)" + "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )" ] }, { @@ -1617,15 +1689,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:52:34,953 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/10/2024 12:52:34 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:19:34,655 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:19:34 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1634,7 +1706,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.007628440856933594, + "elapsed": 0.007803440093994141, "initial": 0, "n": 0, "ncols": null, @@ -1648,7 +1720,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "782b6df90d9f45c6ad938911e8dd4d04", + "model_id": "b7ac4c82ceb14075b551d2be423f0417", "version_major": 2, "version_minor": 0 }, @@ -1663,10 +1735,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-10 12:52:37,926 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/10/2024 12:52:37 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/2 [00:00 [ Donald Trump]\n", - "Cached context templates [['{}'], ['The 2018-19 NBA season is. {}', 'Therefore, we will not discuss the details of. {}', 'Because the number of people living with diabetes continues. {}', 'I have always been interested in the history of. {}', 'You may also wish to search for items by. {}']]\n", + "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n", "Computing right vector (v)\n", "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token: President\n", "Rewrite layer is 8\n", "Tying optimization objective to 31\n", "Recording initial value of v*\n", - "loss 4.299 = 4.299 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.09041433036327362\n", - "loss 3.399 = 3.396 + 0.001 + 0.002 avg prob of [ Donald Trump] 0.21825557947158813\n", - "loss 2.767 = 2.761 + 0.003 + 0.002 avg prob of [ Donald Trump] 0.4464951753616333\n", - "loss 2.353 = 2.35 + 0.0 + 0.003 avg prob of [ Donald Trump] 0.7211445569992065\n", - "loss 2.262 = 2.258 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8048359751701355\n", - "loss 2.242 = 2.237 + 0.0 + 0.004 avg prob of [ Donald Trump] 0.8246363401412964\n", - "loss 2.242 = 2.237 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8255119323730469\n", - "loss 2.237 = 2.231 + 0.0 + 0.005 avg prob of [ Donald Trump] 0.8306096792221069\n", - "loss 2.235 = 2.23 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8323876261711121\n", - "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8328641653060913\n", - "loss 2.235 = 2.229 + 0.0 + 0.006 avg prob of [ Donald Trump] 0.8330349326133728\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.833112359046936\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331531286239624\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331767916679382\n", - "loss 2.236 = 2.229 + 0.0 + 0.007 avg prob of [ Donald Trump] 0.8331915736198425\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332012891769409\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332079648971558\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332127332687378\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332163691520691\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332261443138123\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332351446151733\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332424163818359\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332486152648926\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332538604736328\n", - "loss 2.237 = 2.229 + 0.0 + 0.008 avg prob of [ Donald Trump] 0.8332585096359253\n", - "Init norm 46.17322540283203 | Delta norm 34.629920959472656 | Target norm 58.87141799926758\n", + "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n", + "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n", + "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n", + "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n", + "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n", + "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n", + "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n", + "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n", + "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n", + "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n", + "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n", + "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n", + "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n", + "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n", "\n", "\n", "LAYER 4\n", "\n", "Writing 1 key/value pair(s) into layer 4\n", - "z error tensor(58.9720, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.5992, device='cuda:1')\n", - "upd norm tensor(1.2758, device='cuda:1', grad_fn=)\n", + "z error tensor(4.1196, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4221, device='cuda:0')\n", + "upd norm tensor(0.2889, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 5\n", "\n", "Writing 1 key/value pair(s) into layer 5\n", - "z error tensor(57.4963, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9419, device='cuda:1')\n", - "upd norm tensor(1.8219, device='cuda:1', grad_fn=)\n", + "z error tensor(3.9697, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7669, device='cuda:0')\n", + "upd norm tensor(0.3079, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 6\n", "\n", "Writing 1 key/value pair(s) into layer 6\n", - "z error tensor(53.5156, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9026, device='cuda:1')\n", - "upd norm tensor(2.2746, device='cuda:1', grad_fn=)\n", + "z error tensor(3.8253, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7295, device='cuda:0')\n", + "upd norm tensor(0.3584, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 7\n", "\n", "Writing 1 key/value pair(s) into layer 7\n", - "z error tensor(48.2179, device='cuda:1', grad_fn=)\n", - "orig norm tensor(79.0248, device='cuda:1')\n", - "upd norm tensor(2.9401, device='cuda:1', grad_fn=)\n", + "z error tensor(3.6551, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8482, device='cuda:0')\n", + "upd norm tensor(0.4613, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 8\n", "\n", "Writing 1 key/value pair(s) into layer 8\n", - "z error tensor(42.1711, device='cuda:1', grad_fn=)\n", - "orig norm tensor(78.7670, device='cuda:1')\n", - "upd norm tensor(5.2222, device='cuda:1', grad_fn=)\n" + "z error tensor(3.1808, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5938, device='cuda:0')\n", + "upd norm tensor(0.7683, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:22<00:22, 22.61s/it]" + " 50%|█████ | 1/2 [00:22<00:22, 22.02s/it]" ] }, { @@ -1768,79 +1838,79 @@ "Rewrite layer is 8\n", "Tying optimization objective to 31\n", "Recording initial value of v*\n", - "loss 8.092 = 8.092 + 0.0 + 0.0 avg prob of [ Joe Biden] 0.0012095430865883827\n", - "loss 6.742 = 6.741 + 0.0 + 0.001 avg prob of [ Joe Biden] 0.00625673308968544\n", - "loss 4.74 = 4.738 + 0.0 + 0.002 avg prob of [ Joe Biden] 0.07003729790449142\n", - "loss 3.583 = 3.581 + 0.0 + 0.002 avg prob of [ Joe Biden] 0.2744549512863159\n", - "loss 3.135 = 3.132 + 0.0 + 0.003 avg prob of [ Joe Biden] 0.46447789669036865\n", - "loss 3.057 = 3.054 + 0.0 + 0.003 avg prob of [ Joe Biden] 0.5108730792999268\n", - "loss 2.996 = 2.992 + 0.001 + 0.004 avg prob of [ Joe Biden] 0.5502802133560181\n", - "loss 2.973 = 2.962 + 0.008 + 0.004 avg prob of [ Joe Biden] 0.5703158378601074\n", - "loss 2.968 = 2.964 + 0.0 + 0.004 avg prob of [ Joe Biden] 0.5694985389709473\n", - "loss 2.908 = 2.903 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.6119259595870972\n", - "loss 2.894 = 2.889 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.623217761516571\n", - "loss 2.85 = 2.845 + 0.0 + 0.005 avg prob of [ Joe Biden] 0.6566493511199951\n", - "loss 2.835 = 2.829 + 0.001 + 0.005 avg prob of [ Joe Biden] 0.6692723631858826\n", - "loss 2.809 = 2.798 + 0.005 + 0.005 avg prob of [ Joe Biden] 0.6948184967041016\n", - "loss 2.775 = 2.768 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.7195937633514404\n", - "loss 2.761 = 2.755 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7305320501327515\n", - "loss 2.734 = 2.728 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7539100050926208\n", - "loss 2.731 = 2.725 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7571267485618591\n", - "loss 2.714 = 2.708 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.772413432598114\n", - "loss 2.712 = 2.705 + 0.0 + 0.006 avg prob of [ Joe Biden] 0.7749055624008179\n", - "loss 2.697 = 2.69 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.788467288017273\n", - "loss 2.691 = 2.684 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.7946388125419617\n", - "loss 2.682 = 2.675 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.803372859954834\n", - "loss 2.679 = 2.671 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.8066705465316772\n", - "loss 2.675 = 2.668 + 0.001 + 0.006 avg prob of [ Joe Biden] 0.8101906776428223\n", - "Init norm 58.270999908447266 | Delta norm 43.70325469970703 | Target norm 71.81600189208984\n", + "loss 7.943 = 7.943 + 0.0 + 0.0 avg prob of [ Joe Biden] 0.0019189908634871244\n", + "loss 1.041 = 0.978 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.3972981572151184\n", + "loss 0.508 = 0.444 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.6710776090621948\n", + "loss 0.27 = 0.204 + 0.01 + 0.055 avg prob of [ Joe Biden] 0.8317974805831909\n", + "loss 0.163 = 0.093 + 0.014 + 0.055 avg prob of [ Joe Biden] 0.9146818518638611\n", + "loss 0.117 = 0.045 + 0.017 + 0.055 avg prob of [ Joe Biden] 0.956814169883728\n", + "loss 0.099 = 0.026 + 0.017 + 0.055 avg prob of [ Joe Biden] 0.9744102358818054\n", + "loss 0.089 = 0.018 + 0.016 + 0.055 avg prob of [ Joe Biden] 0.9822758436203003\n", + "loss 0.082 = 0.013 + 0.013 + 0.055 avg prob of [ Joe Biden] 0.9866634607315063\n", + "loss 0.077 = 0.01 + 0.011 + 0.055 avg prob of [ Joe Biden] 0.9895979762077332\n", + "loss 0.072 = 0.008 + 0.008 + 0.055 avg prob of [ Joe Biden] 0.9917430877685547\n", + "loss 0.068 = 0.007 + 0.006 + 0.055 avg prob of [ Joe Biden] 0.9933518767356873\n", + "loss 0.066 = 0.005 + 0.005 + 0.055 avg prob of [ Joe Biden] 0.9945577383041382\n", + "loss 0.064 = 0.005 + 0.004 + 0.055 avg prob of [ Joe Biden] 0.9954613447189331\n", + "loss 0.062 = 0.004 + 0.003 + 0.055 avg prob of [ Joe Biden] 0.9961475133895874\n", + "loss 0.061 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9966810941696167\n", + "loss 0.06 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9971078634262085\n", + "loss 0.06 = 0.003 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9974588751792908\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.9977539777755737\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.998005747795105\n", + "loss 0.059 = 0.002 + 0.002 + 0.055 avg prob of [ Joe Biden] 0.998222291469574\n", + "loss 0.058 = 0.002 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9984093904495239\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9985707998275757\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9987096786499023\n", + "loss 0.058 = 0.001 + 0.001 + 0.055 avg prob of [ Joe Biden] 0.9988290071487427\n", + "Init norm 6.77515172958374 | Delta norm 5.081363677978516 | Target norm 8.35883903503418\n", "\n", "\n", "LAYER 4\n", "\n", "Writing 1 key/value pair(s) into layer 4\n", - "z error tensor(56.2028, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.6094, device='cuda:1')\n", - "upd norm tensor(1.0710, device='cuda:1', grad_fn=)\n", + "z error tensor(5.0814, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.4227, device='cuda:0')\n", + "upd norm tensor(0.1920, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 5\n", "\n", "Writing 1 key/value pair(s) into layer 5\n", - "z error tensor(55.3878, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9589, device='cuda:1')\n", - "upd norm tensor(1.4956, device='cuda:1', grad_fn=)\n", + "z error tensor(4.9171, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7672, device='cuda:0')\n", + "upd norm tensor(0.3400, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 6\n", "\n", "Writing 1 key/value pair(s) into layer 6\n", - "z error tensor(53.6229, device='cuda:1', grad_fn=)\n", - "orig norm tensor(77.9275, device='cuda:1')\n", - "upd norm tensor(1.8921, device='cuda:1', grad_fn=)\n", + "z error tensor(4.6810, device='cuda:0', grad_fn=)\n", + "orig norm tensor(77.7298, device='cuda:0')\n", + "upd norm tensor(0.5723, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 7\n", "\n", "Writing 1 key/value pair(s) into layer 7\n", - "z error tensor(51.1769, device='cuda:1', grad_fn=)\n", - "orig norm tensor(79.0695, device='cuda:1')\n", - "upd norm tensor(2.6346, device='cuda:1', grad_fn=)\n", + "z error tensor(4.2653, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.8484, device='cuda:0')\n", + "upd norm tensor(0.7971, device='cuda:0', grad_fn=)\n", "\n", "\n", "LAYER 8\n", "\n", "Writing 1 key/value pair(s) into layer 8\n", - "z error tensor(47.3937, device='cuda:1', grad_fn=)\n", - "orig norm tensor(78.9209, device='cuda:1')\n", - "upd norm tensor(4.8986, device='cuda:1', grad_fn=)\n" + "z error tensor(3.4427, device='cuda:0', grad_fn=)\n", + "orig norm tensor(78.5947, device='cuda:0')\n", + "upd norm tensor(1.1133, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:41<00:00, 20.82s/it]" + "100%|██████████| 2/2 [00:40<00:00, 20.50s/it]" ] }, { @@ -1856,26 +1926,25 @@ "output_type": "stream", "text": [ "\n", - "2024-11-10 12:53:26,318 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "2024-11-14 19:20:29,243 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:53:26 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:20:29 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", - " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "2024-11-10 12:53:26,391 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n", + "2024-11-14 19:20:29,316 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/10/2024 12:53:26 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", + "11/14/2024 19:20:29 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", - " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n" + " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.0}}\n", - "[{'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}]\n" + "Metrics Summary: {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.75}}\n" ] } ], @@ -1883,46 +1952,34 @@ "from easyeditor import AlphaEditHyperParams\n", "\n", "hparams = AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n", - "hparams.device = 1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - "\n", - " sequential_edit=True,\n", + " sequential_edit=True\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", - " warnings.warn(\n", - "/mnt/8t/xkw/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", - " warnings.warn(\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: [' Biden:// Biden:// Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", - "Generalization: [' Biden:// Bidenating Biden Joeating Biden:// Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", - "Locality: [' Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n", - "Portability: [' Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden Biden']\n" + "Reliability : The current President of the United States is Joe Biden.\n", + "Generalization: The current President of the United States is Joe Biden.\n", + "Locality : The capital of the United States is Washington, D.C. (short for District of Columbia).\n", + "Portability : The current U.S. President, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n" ] } ], "source": [ - "evaluate_chat_template(edited_model,Evaluation_prompts, Evaluation_metrics, device=1)" + "evaluate_chat_template(edited_model,Evaluation_prompts, Evaluation_metrics, device=hparams.device)" ] }, { @@ -1934,15 +1991,15 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:40:52,347 - easyeditor.editors.editor - INFO - Instantiating model\n", - "11/11/2024 15:40:52 - INFO - easyeditor.editors.editor - Instantiating model\n" + "2024-11-14 19:42:54,874 - easyeditor.editors.editor - INFO - Instantiating model\n", + "11/14/2024 19:42:54 - INFO - easyeditor.editors.editor - Instantiating model\n" ] }, { @@ -1951,7 +2008,7 @@ "ascii": false, "bar_format": null, "colour": null, - "elapsed": 0.009042024612426758, + "elapsed": 0.007505655288696289, "initial": 0, "n": 0, "ncols": null, @@ -1965,7 +2022,7 @@ "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { - "model_id": "2a15c9aabd9a4ff19a2ffa857027c62c", + "model_id": "0d7020acb56b4539a3d1c1b408bc18f4", "version_major": 2, "version_minor": 0 }, @@ -1980,10 +2037,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-11-11 15:40:55,257 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", - "11/11/2024 15:40:55 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "2024-11-14 19:42:58,379 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", + "11/14/2024 19:42:58 - INFO - easyeditor.editors.editor - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n", " 0%| | 0/2 [00:00 Donald Trump \n", + "100%|██████████| 2/2 [00:30<00:00, 15.45s/it]\n", + "2024-11-14 19:43:36,695 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 15:41:32 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", + "11/14/2024 19:43:36 - INFO - easyeditor.editors.editor - 0 editing: Who is the current President of the United States? -> Donald Trump \n", "\n", " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n", - "2024-11-11 15:41:32,447 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "2024-11-14 19:43:36,779 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n", - "11/11/2024 15:41:32 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", + "11/14/2024 19:43:36 - INFO - easyeditor.editors.editor - 1 editing: Who is the current President of the United States? -> Joe Biden \n", "\n", " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n" ] @@ -2741,36 +2798,34 @@ "from easyeditor import LoRAHyperParams\n", "\n", "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n", - "hparams.device = 1\n", "editor = BaseEditor.from_hparams(hparams)\n", "metrics, edited_model, _ = editor.edit(\n", " prompts=prompts,\n", " ground_truth=ground_truth,\n", " target_new=target_new,\n", " subject=subject,\n", - "\n", " sequential_edit=True,\n", ")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Joe Biden Biden Joe Biden Joe Biden Joe Biden Joe Biden Biden Joe Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden\n", - "Generalization: Joe Biden Biden Biden Biden Biden Joe Biden Biden Biden Biden Biden Joe Biden Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden Joe Biden Biden\n", - "Locality: Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe\n", - "Portability: Joe Biden's Biden Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe Biden's Joe\n" + "Reliability : Joe Biden Joe Biden Joe Biden Joe Biden\n", + "Generalization: Joe Biden Biden Biden Biden Biden Biden Biden\n", + "Locality : Joe Biden's Joe Biden's Joe Biden\n", + "Portability : Joe Biden's Joe Biden's Joe Biden\n" ] } ], "source": [ - "evaluate_chat_template(edited_model, Evaluation_prompts, Evaluation_metrics,device=1)" + "evaluate_chat_template(edited_model, Evaluation_prompts, Evaluation_metrics,device=hparams.device)" ] }, { @@ -2782,7 +2837,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2820,36 +2875,40 @@ "source": [ "from transformers import LlamaForCausalLM\n", "\n", - "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda:1')\n" + "# load the original model\n", + "device = 0\n", + "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to(f'cuda:{device}')\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Reliability: Based on the information provided, the current President of the United States is Joe Biden.\n", - "Generalization: Based on the information provided, the head of state in the United States right now is Joe Biden.\n", - "Locality: The capital of the United States is Washington, D.C.\n", - "Portability: Based on the information provided, the current U.S. President is Joe Biden. Joe Biden was born in Scranton, Pennsylvania, and later moved to Wilmington, Delaware, where he grew up.\n" + "Reliability : Based on the information provided, the current President of the United States is Joe Biden.\n", + "Generalization: Based on the information provided, the current President of the United States is Biden.\n", + "Locality : The capital of the United States is Washington, D.C.\n", + "Portability : Based on the information provided, the current U.S. President is Joe Biden. According to public records, Joe Biden was born in Scranton, Pennsylvania, and later grew up in Wilmington, Delaware.\n" ] } ], "source": [ - "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", - " 'Who is the head of state in the United States right now?',\n", - " \"Where is the capital of the United States?\" ,\n", - " 'Where is the current U.S. President born?']\n", - "\n", - "# edit_prompt = 'The current President of the United States is Joe Biden.'\n", - "edit_prompt = 'Information: The U.S. President changed from Biden to Trump, and finally back to Biden again. Based on the information, answer the following questions and dont answer I cant provide information:'\n", + "# original evaluation prompts\n", + "Evaluation_prompts = [ \"Who is the current President of the United States?\" ,\n", + " \"What is the name of the current President of the United States?\",\n", + " \"Where is the capital of the United States?\" ,\n", + " \"Where is the current U.S. President born ?\"]\n", "\n", + "# add edit prompt of the U.S. President change\n", + "edit_prompt = 'Information: The U.S. President changed from Biden to Trump, \\\n", + " and finally back to Biden again. Based on the information, \\\n", + " answer the following questions and dont answer I cant provide information:'\n", "Evaluation_prompts = [ edit_prompt + ' ' + prompt for prompt in Evaluation_prompts]\n", - "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=1)\n" + "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=device)" ] } ], From df4481b19459d751f6c215821bfb3a24736c0444 Mon Sep 17 00:00:00 2001 From: KeweiXu Date: Sun, 17 Nov 2024 17:05:41 +0800 Subject: [PATCH 5/5] EasyEdit Example For US President --- easyeditor/editors/steer_editor.py | 2 +- hparams/AlphaEdit/llama3-8b.yaml | 3 +- .../EasyEdit_Example_US_President.ipynb | 67 ++++++++++--------- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/easyeditor/editors/steer_editor.py b/easyeditor/editors/steer_editor.py index fd829529..d77a21cc 100644 --- a/easyeditor/editors/steer_editor.py +++ b/easyeditor/editors/steer_editor.py @@ -16,7 +16,7 @@ from transformers import AutoProcessor, LlavaForConditionalGeneration from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration from ..util.globals import * -from ..evaluate import compute_safety_edit_quality, ccks_compute_safety_edit_quality +from ..evaluate import compute_safety_edit_quality from ..util import nethook from ..util.hparams import HyperParams from ..util.alg_dict import * diff --git a/hparams/AlphaEdit/llama3-8b.yaml b/hparams/AlphaEdit/llama3-8b.yaml index 7827a537..7fb30311 100644 --- a/hparams/AlphaEdit/llama3-8b.yaml +++ b/hparams/AlphaEdit/llama3-8b.yaml @@ -1,11 +1,10 @@ alg_name: "AlphaEdit" model_name: "./hugging_cache/llama-3-8b-instruct" -# model_name: "./Meta-Llama-3-8B-Instruct" stats_dir: "./data/stats" # Make sure that the projection matrix P has been downloaded from the baidu netdisk (For details, please refer to the EasyEdit/easyeditor/models/alphaedit/README.md) beforehand to avoid double computation. # But if the projection matrix P which we provided is not needed, then nothing needs to be done to the P_loc field; # just run the program, and the program will compute P and save it locally automatically. -P_loc: "./null_space_project_fjz.pt" +P_loc: "./null_space_project.pt" device: 0 layers: [4, 5, 6, 7, 8] clamp_norm_factor: 0.75 diff --git a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb index 86ffa5a0..731d795a 100644 --- a/tutorial-notebooks/EasyEdit_Example_US_President.ipynb +++ b/tutorial-notebooks/EasyEdit_Example_US_President.ipynb @@ -235,6 +235,14 @@ "## Experiment" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Editing GPU memory usage\n", + "Editing llama-3-8B-instruct requires 40G VRAM on GPU." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -253,7 +261,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Clone repository, create environment Python 3.9, and install relevant libraries:" + "Clone repository, create environment Python 3.9, and install relevant libraries.
\n", + "Please execute the following command on the **Terminal**:" ] }, { @@ -263,70 +272,64 @@ "outputs": [], "source": [ "## Clone Repo\n", - "!git clone https://github.com/zjunlp/EasyEdit.git\n", - "%cd EasyEdit\n", - "!ls\n", + "git clone https://github.com/zjunlp/EasyEdit.git\n", + "cd EasyEdit\n", "\n", - "!apt-get install python3.9\n", - "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1\n", - "!sudo update-alternatives --config python3\n", - "!apt-get install python3-pip\n", - "%pip install -r requirements.txt" + "## Create Environment\n", + "conda create -n EasyEdit python=3.9\n", + "conda activate EasyEdit\n", + "pip install -r requirements.txt\n", + "\n", + "## Install Jupyter Notebook environment dependencies\n", + "pip install ipykernel\n", + "pip install ipywidgets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If you have already cloned the repository and the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." + "#### Download Model" ] }, { - "cell_type": "code", - "execution_count": 1, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/mnt/8t/xkw/EasyEdit\n" - ] - } - ], "source": [ - "%cd .." + "Use the following command to download `llama-3-8b-instruct` to the specified folder.
meta-llama needs to log in to apply for permission and add the `-- token your_token` in huggingface cli.\n", + "For more information, please refer to [huggingface](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "#### Download Model" + "!huggingface-cli download meta-llama/Llama-3-8B-Instruct --local-dir ./hugging_cache/llama-3-8b-instruct --token your_token" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Use the following command to download `llama-3-8b-instruct` to the specified folder.
meta-llama needs to log in to apply for permission and add the `-- token your_token` in huggingface cli.\n", - "For more information, please refer to [huggingface](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)\n" + "#### Load the evaluation function " ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "!huggingface-cli download meta-llama/Llama-3-8B-Instruct --local-dir ./hugging_cache/llama-3-8b-instruct --token your_token" + "If the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path." ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "#### Load the evaluation function " + "%cd .." ] }, {