forked from masa3141/japanese-alpaca-lora
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/masa3141/japanese-alpaca-lora/blob/master/notebooks/translate.ipynb)\n" | ||
], | ||
"metadata": { | ||
"id": "XfZIRqcJbJQj" | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Translate\n", | ||
"Translated the [alpaca_data.json](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) to japanese using ChatGPT API.\n", | ||
"We paid around US $45 to translate the full dataset to japanese. Translated data is available. ([japanese_alpaca_data.json](https://github.com/masa3141/japanese-alpaca-lora/blob/main/data/japanese_alpaca_data.json))" | ||
], | ||
"metadata": { | ||
"id": "mTxx9QR6bPC3" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install openai" | ||
], | ||
"metadata": { | ||
"id": "W3SO2vgzcBpF" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import openai\n", | ||
"import json\n", | ||
"from concurrent.futures import ThreadPoolExecutor, as_completed\n", | ||
"from tqdm import tqdm\n", | ||
"import os" | ||
], | ||
"metadata": { | ||
"id": "3Ym5s8XqbxoJ" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"openai.api_key = ''" | ||
], | ||
"metadata": { | ||
"id": "g-JHrX2EcAiT" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Recommeding to store the data in mounted google drive\n", | ||
"!mkdir translated_data translated_data/data translated_data/error " | ||
], | ||
"metadata": { | ||
"id": "QsQS8Fvacekr" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Since it doesn't succeed in one attempt, it is necessary to execute multiple times, so from the next time onwards, only translate things that do not exist.\n", | ||
"translated_files = set(os.listdir('translated_data/data'))" | ||
], | ||
"metadata": { | ||
"id": "DK9rvJjydCZo" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def translate_text(value):\n", | ||
" response = openai.ChatCompletion.create(\n", | ||
" model=\"gpt-3.5-turbo\",\n", | ||
" messages=[\n", | ||
" {\"role\": \"system\", \"content\": \"日本語に翻訳するAIアシスタントです。<start><end>で囲まれた文章を日本語に翻訳しなさい。\"},\n", | ||
" {\"role\": \"user\", \"content\": f\"<start>'{value}'<end>\\n 日本語訳: \"},\n", | ||
" ],\n", | ||
" max_tokens=1024,\n", | ||
" temperature=0,\n", | ||
" )\n", | ||
" return response.choices[0][\"message\"][\"content\"].strip().replace(\"<start>\", \"\").replace(\"<end>\", \"\")\n", | ||
"\n", | ||
"def translate_item(item):\n", | ||
" translated_item = {}\n", | ||
" for key, value in item.items():\n", | ||
" if value:\n", | ||
" translated_value = translate_text(value)\n", | ||
" translated_item[key] = translated_value\n", | ||
" else:\n", | ||
" translated_item[key] = ''\n", | ||
" return translated_item\n", | ||
"\n", | ||
"def save_item(item, file_name):\n", | ||
" with open(file_name, 'w') as f:\n", | ||
" json.dump(item, f, ensure_ascii=False, indent=4)\n", | ||
"\n", | ||
"def translate_save(item, i):\n", | ||
" if f\"translated_{i}.json\" in translated_files:\n", | ||
" return\n", | ||
" try:\n", | ||
" translated_item = translate_item(item)\n", | ||
" save_item(translated_item, f\"translated_data/data/translated_{i}.json\")\n", | ||
" except Exception as e:\n", | ||
" print(f\"translated_{i}.json: {e}\")\n", | ||
" with open(f\"translated_data/error/translated_{i}.json\", 'a'):\n", | ||
" pass" | ||
], | ||
"metadata": { | ||
"id": "UGJhLnbDcKm2" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Please upload alpaca_data.json\n", | ||
"with open('alpaca_data.json', 'r') as f:\n", | ||
" data = json.load(f)\n" | ||
], | ||
"metadata": { | ||
"id": "KxRPToBUdaet" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Translate in parallel\n", | ||
"with ThreadPoolExecutor(max_workers=100) as executor:\n", | ||
" futures = {executor.submit(translate_save, item, i) for i, item in enumerate(data)}\n", | ||
" \n", | ||
" for future in tqdm(as_completed(futures), total=len(futures), desc=\"Translating\"):\n", | ||
" future.result()\n" | ||
], | ||
"metadata": { | ||
"id": "jRiOYhlPdode" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Since it doesn't succeed in one attempt, it is necessary to execute multiple times. Please try untill all files are translated. It took US $45 and 5 hours." | ||
], | ||
"metadata": { | ||
"id": "BjnJ6ZJSehvb" | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## After finishing translation, merge those files into one file" | ||
], | ||
"metadata": { | ||
"id": "la267FnHeNqA" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def merge_json_files(data_folder):\n", | ||
" merged_data = []\n", | ||
" for i in range(52002):\n", | ||
" print(i)\n", | ||
" file_path = os.path.join(data_folder, f\"translated_{i}.json\")\n", | ||
" with open(file_path, 'r', encoding=\"utf-8\") as file:\n", | ||
" data = json.load(file)\n", | ||
" merged_data.append(data)\n", | ||
" return merged_data\n", | ||
"\n", | ||
"def write_merged_json_file(output_file, merged_data):\n", | ||
" with open(output_file, 'w', encoding=\"utf-8\") as file:\n", | ||
" json.dump(merged_data, file, indent=2, ensure_ascii=False)\n", | ||
"\n", | ||
"data_folder = 'translated_data/data'\n", | ||
"output_file = 'japanese_alpaca_data.json'\n", | ||
"\n", | ||
"merged_data = merge_json_files(data_folder)\n", | ||
"write_merged_json_file(output_file, merged_data)" | ||
], | ||
"metadata": { | ||
"id": "7MZC2pfNeUXN" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |