sensai/examples/transformers.ipynb

133 lines
3.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from datasets import ClassLabel, Value\n",
"from datasets import load_dataset, Features\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"\n",
"# Dataset loading reference: https://huggingface.co/docs/datasets/loading_datasets.html\n",
"\n",
"# Local data directory; can be overridden through the DATASET_DIR environment variable.\n",
"# NOTE(review): no later cell reads DATASET_DIR -- the data is pulled from the Hub (\"holodata/sensai\").\n",
"DATASET_DIR = os.getenv(\"DATASET_DIR\", \"../input/sensai\")\n",
"print(\"DATASET_DIR\", DATASET_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Select the train split directly so `dataset` is always a Dataset (not a DatasetDict),\n",
"# casting \"body\" to text and \"toxic\" to a 2-class label.\n",
"dataset = load_dataset(\n",
"    \"holodata/sensai\",\n",
"    split=\"train\",\n",
"    features=Features(\n",
"        {\n",
"            \"body\": Value(\"string\"),\n",
"            \"toxic\": ClassLabel(num_classes=2, names=['0', '1']),\n",
"        }\n",
"    ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the schema produced by the explicit Features cast (\"body\": string, \"toxic\": ClassLabel).\n",
"dataset.features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model and tokenizer must come from the same checkpoint, so name it once.\n",
"MODEL_NAME = \"bert-base-cased\"\n",
"# Derive the head size from the dataset's ClassLabel instead of hard-coding it.\n",
"num_labels = dataset.features[\"toxic\"].num_classes\n",
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A fixed seed keeps the subsample reproducible; the cap avoids an IndexError on smaller datasets.\n",
"SEED = 42\n",
"N_SAMPLES = 50000\n",
"samples = dataset.shuffle(seed=SEED).select(range(min(N_SAMPLES, len(dataset))))\n",
"\n",
"def tokenize_function(examples):\n",
"    \"\"\"Tokenize a batch of comment bodies, padding to the tokenizer's max length and truncating longer ones.\"\"\"\n",
"    return tokenizer(examples[\"body\"], padding=\"max_length\", truncation=True)\n",
"\n",
"tokenized_datasets = samples.map(tokenize_function, batched=True)\n",
"# rename_column returns a new Dataset; the in-place rename_column_ is deprecated and removed in datasets 2.x.\n",
"# Trainer's default collator expects the target in a \"label\" (or \"labels\") column.\n",
"tokenized_datasets = tokenized_datasets.rename_column(\"toxic\", \"label\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hold out 20% for evaluation; a fixed seed makes the split reproducible across runs.\n",
"splitset = tokenized_datasets.train_test_split(test_size=0.2, seed=42)\n",
"splitset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# output_dir is where Trainer saves checkpoints and where resuming looks for them.\n",
"# NOTE(review): with no evaluation strategy set, eval_dataset is only used by trainer.evaluate().\n",
"training_args = TrainingArguments(output_dir=\"test_trainer\")\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=splitset['train'],\n",
"    eval_dataset=splitset['test'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# resume_from_checkpoint=True raises when output_dir holds no checkpoint (e.g. a fresh run),\n",
"# so locate the latest checkpoint explicitly and train from scratch when there is none.\n",
"last_checkpoint = (\n",
"    get_last_checkpoint(training_args.output_dir)\n",
"    if os.path.isdir(training_args.output_dir)\n",
"    else None\n",
")\n",
"trainer.train(resume_from_checkpoint=last_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "c8daecebaf2d81430b8373e4b4af380b12df116248cd1bbadd3fc947f45a1f88"
},
"kernelspec": {
"display_name": "Python 3.8.6 64-bit",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}