Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
finlay-liu authored Aug 16, 2024
1 parent f2ca736 commit 115cd78
Showing 1 changed file with 171 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "714855ed-9040-44b1-a930-86edf0952277",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4428241dd396428683892fe51ed5d438",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"device = \"cuda\" # the device to load the model onto\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"/home/lyz/hf-models/Qwen/Qwen1.5-4B-Chat/\",\n",
" torch_dtype=\"auto\",\n",
" device_map=\"auto\"\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"/home/lyz/hf-models/Qwen/Qwen1.5-4B-Chat/\")\n",
"\n",
"prompt = \"hello\"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
"]\n",
"text = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True\n",
")\n",
"model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
"\n",
"generated_ids = model.generate(\n",
" model_inputs.input_ids,\n",
" max_new_tokens=512\n",
")\n",
"generated_ids = [\n",
" output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
"]\n",
"\n",
"response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "579d0f7f-a511-4d53-9b6c-a4cd1fcc2b87",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"test = pd.read_csv('./test_input.csv', header=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a06baea3-940e-4a14-885b-b8175075efdc",
"metadata": {},
"outputs": [],
"source": [
"for test_prompt in test[0].values:\n",
" prompt = f\"列举与下面句子最相关的五个成语。只需要输出五个成语,不需要有其他的输出,写在一行中:{test_prompt}\"\n",
"\n",
" words = ['同舟共济'] * 5\n",
" for _ in range(10):\n",
" messages = [\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ]\n",
" text = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True\n",
" )\n",
" model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
" \n",
" generated_ids = model.generate(\n",
" model_inputs.input_ids,\n",
" max_new_tokens=512\n",
" )\n",
" generated_ids = [\n",
" output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
" ]\n",
" \n",
" response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
" response = response.replace('\\n', ' ').replace('、', ' ')\n",
" words = [x for x in response.split() if len(x) == 4 and x.strip() != '']\n",
" if len(words) == 5:\n",
" break\n",
"\n",
"\n",
" if len(' '.join(words).strip()) != 24:\n",
" words = ['同舟共济'] * 5\n",
"\n",
" with open('submit.csv', 'a+') as up:\n",
" up.write(' '.join(words) + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9e852a9a-6aae-4bfa-b18b-030e593b1e77",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len('一模一样 如出一辙 千篇一律 大同小异 毫无二致')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60d6f9b8-8f12-40bc-baca-27198dd4989d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py3.11",
"language": "python",
"name": "py3.11"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 115cd78

Please sign in to comment.