mirror of
https://github.com/holodata/sensai-dataset.git
synced 2025-03-15 12:00:32 +09:00
133 lines
3.2 KiB
Plaintext
133 lines
3.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import os\n",
  "\n",
  "from datasets import ClassLabel, Features, Value, load_dataset\n",
  "from transformers import (\n",
  "    AutoModelForSequenceClassification,\n",
  "    AutoTokenizer,\n",
  "    Trainer,\n",
  "    TrainingArguments,\n",
  ")\n",
  "\n",
  "# Dataset loading reference: https://huggingface.co/docs/datasets/loading_datasets.html\n",
  "\n",
  "# Read DATASET_DIR from the environment, falling back to the relative default path.\n",
  "DATASET_DIR = os.getenv(\"DATASET_DIR\", \"../input/sensai\")\n",
  "print(\"DATASET_DIR\", DATASET_DIR)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Explicit schema: a string \"body\" column and a two-class \"toxic\" label column.\n",
  "sensai_features = Features({\n",
  "    \"body\": Value(\"string\"),\n",
  "    \"toxic\": ClassLabel(num_classes=2, names=['0', '1']),\n",
  "})\n",
  "\n",
  "# The loaded object is indexed by split name; only the 'train' split is kept.\n",
  "dataset = load_dataset(\"holodata/sensai\", features=sensai_features)['train']"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataset.features"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Tokenizer and classifier are both loaded from the same pretrained checkpoint name.\n",
  "PRETRAINED_NAME = \"bert-base-cased\"\n",
  "\n",
  "tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)\n",
  "model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_NAME, num_labels=2)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "SAMPLE_SIZE = 50000\n",
  "\n",
  "# Cap the sample at the dataset size so select() cannot index past the end, and seed the\n",
  "# shuffle so a restarted kernel draws the same rows.\n",
  "samples = dataset.shuffle(seed=42).select(range(min(SAMPLE_SIZE, len(dataset))))\n",
  "\n",
  "\n",
  "def tokenize_function(examples):\n",
  "    \"\"\"Tokenize a batch of rows by their \"body\" text.\n",
  "\n",
  "    Called by Dataset.map with batched=True, so examples[\"body\"] is a list of strings.\n",
  "    Sequences are truncated and padded to the tokenizer's maximum length.\n",
  "    \"\"\"\n",
  "    return tokenizer(examples[\"body\"], padding=\"max_length\", truncation=True)\n",
  "\n",
  "\n",
  "tokenized_datasets = samples.map(tokenize_function, batched=True)\n",
  "# The in-place rename_column_() was deprecated and later removed from `datasets`;\n",
  "# rename_column() returns a new dataset, so the result must be reassigned.\n",
  "tokenized_datasets = tokenized_datasets.rename_column(\"toxic\", \"label\")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Seed the split so the same input rows always land in the same partition. Training is\n",
  "# resumed from checkpoints, and an unseeded split would let a resumed run be evaluated on\n",
  "# rows that an earlier run trained on.\n",
  "splitset = tokenized_datasets.train_test_split(test_size=0.2, seed=42)\n",
  "splitset"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Default training arguments; outputs go to the \"test_trainer\" directory.\n",
  "training_args = TrainingArguments(output_dir=\"test_trainer\")\n",
  "\n",
  "trainer = Trainer(\n",
  "    model=model,\n",
  "    args=training_args,\n",
  "    train_dataset=splitset['train'],\n",
  "    eval_dataset=splitset['test'],\n",
  ")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "from transformers.trainer_utils import get_last_checkpoint\n",
  "\n",
  "# resume_from_checkpoint=True raises ValueError when the output directory holds no\n",
  "# checkpoint (e.g. the very first run), so resume only when one is actually present.\n",
  "# get_last_checkpoint() lists the directory, hence the isdir() guard.\n",
  "last_checkpoint = (\n",
  "    get_last_checkpoint(training_args.output_dir)\n",
  "    if os.path.isdir(training_args.output_dir)\n",
  "    else None\n",
  ")\n",
  "\n",
  "# None means \"train from scratch\"; a path resumes from that checkpoint.\n",
  "trainer.train(resume_from_checkpoint=last_checkpoint)"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "c8daecebaf2d81430b8373e4b4af380b12df116248cd1bbadd3fc947f45a1f88"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3.8.6 64-bit",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.6"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|