{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune BERT for toxic-comment classification (holodata/sensai)\n",
    "\n",
    "Fine-tunes `bert-base-cased` as a binary classifier on the `toxic` label of the `holodata/sensai` dataset. A seeded random sample of the `train` split is tokenized and divided into train/eval subsets, then trained with the Hugging Face `Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from datasets import ClassLabel, Features, Value, load_dataset\n",
    "from transformers import (\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "from transformers.trainer_utils import get_last_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://huggingface.co/docs/datasets/loading_datasets.html\n",
    "\n",
    "# NOTE(review): DATASET_DIR is printed but not used below; the data is loaded from the Hub via DATASET_NAME.\n",
    "DATASET_DIR = os.environ.get(\"DATASET_DIR\", \"../input/sensai\")\n",
    "print(\"DATASET_DIR\", DATASET_DIR)\n",
    "\n",
    "DATASET_NAME = \"holodata/sensai\"\n",
    "MODEL_NAME = \"bert-base-cased\"\n",
    "N_SAMPLES = 50_000  # rows sampled for train+eval; set to None to use the whole split\n",
    "TEST_SIZE = 0.2  # fraction of the sample held out for evaluation\n",
    "SEED = 42  # seeds the sample shuffle, the train/test split and the Trainer\n",
    "OUTPUT_DIR = \"test_trainer\"  # Trainer output / checkpoint directory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Only the `body` text and the binary `toxic` label are loaded; `toxic` is cast to a `ClassLabel` so it can serve as the classification target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = Features(\n",
    "    {\n",
    "        \"body\": Value(\"string\"),\n",
    "        \"toxic\": ClassLabel(num_classes=2, names=[\"0\", \"1\"]),\n",
    "    }\n",
    ")\n",
    "dataset_raw = load_dataset(DATASET_NAME, features=features)[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_raw.features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample and tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize a batch of comment bodies, padded/truncated to the tokenizer's maximum length.\"\"\"\n",
    "    return tokenizer(examples[\"body\"], padding=\"max_length\", truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded shuffle so that Restart & Run All draws the same sample every time;\n",
    "# the sample size is clamped so a split smaller than N_SAMPLES does not raise.\n",
    "shuffled = dataset_raw.shuffle(seed=SEED)\n",
    "n_rows = len(shuffled) if N_SAMPLES is None else min(N_SAMPLES, len(shuffled))\n",
    "samples = shuffled.select(range(n_rows))\n",
    "\n",
    "# rename_column returns a new Dataset; the in-place rename_column_ is deprecated\n",
    "# (removed in datasets 2.0), so the result must be reassigned.\n",
    "tokenized_dataset = samples.map(tokenize_function, batched=True).rename_column(\"toxic\", \"label\")\n",
    "tokenized_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train / eval split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "splitset = tokenized_dataset.train_test_split(test_size=TEST_SIZE, seed=SEED)\n",
    "splitset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(OUTPUT_DIR, seed=SEED)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=splitset[\"train\"],\n",
    "    eval_dataset=splitset[\"test\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resume_from_checkpoint=True raises a ValueError when OUTPUT_DIR holds no checkpoint\n",
    "# (e.g. on a first run), so resume only from a checkpoint that actually exists.\n",
    "last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None\n",
    "trainer.train(resume_from_checkpoint=last_checkpoint)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c8daecebaf2d81430b8373e4b4af380b12df116248cd1bbadd3fc947f45a1f88"
  },
  "kernelspec": {
   "display_name": "Python 3.8.6 64-bit",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}