Mirror of https://github.com/holodata/sensai-dataset.git
In [ ]:
from datasets import load_dataset, Features, ClassLabel, Value
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import os

# https://huggingface.co/docs/datasets/loading_datasets.html
# Path to a local copy of the dataset; override via the DATASET_DIR env var.
DATASET_DIR = os.environ.get("DATASET_DIR", "../input/sensai")
print("DATASET_DIR", DATASET_DIR)
In [ ]:
# Load the Sensai toxic live chat dataset from the Hugging Face Hub,
# casting the columns to explicit feature types.
dataset = load_dataset("holodata/sensai", features=Features(
    {
        "body": Value("string"),                               # chat message text
        "toxic": ClassLabel(num_classes=2, names=["0", "1"]),  # binary toxicity label
    }
))
dataset = dataset["train"]
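The cell above pulls the dataset from the Hugging Face Hub, which leaves DATASET_DIR unused. If you want to work from the local copy instead (for example, the Kaggle ../input/sensai directory), here is a minimal sketch you could run in place of that cell, assuming the directory holds the dataset's parquet chunks:
In [ ]:
# Hypothetical alternative: load the local parquet copy pointed to by DATASET_DIR.
import glob
dataset = load_dataset(
    "parquet",
    data_files=glob.glob(os.path.join(DATASET_DIR, "*.parquet")),
    features=Features({
        "body": Value("string"),
        "toxic": ClassLabel(num_classes=2, names=["0", "1"]),
    }),
)["train"]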
In [ ]:
dataset.features
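Beyond the schema, a quick look at a few rows helps confirm the data loaded as expected; a small inspection cell (not part of the original notebook):
In [ ]:
# Peek at a few rows to sanity-check the label and text columns.
for row in dataset.select(range(3)):
    print(row["toxic"], row["body"][:80])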
In [ ]:
# Cased BERT encoder with a freshly initialized 2-way classification head;
# num_labels matches the ClassLabel defined above.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
In [ ]:
# Work with a random 50k-message subsample to keep training time manageable
# (seed added for reproducibility).
samples = dataset.shuffle(seed=42).select(range(50000))

def tokenize_function(examples):
    return tokenizer(examples["body"], padding="max_length", truncation=True)

tokenized_datasets = samples.map(tokenize_function, batched=True)
# The in-place rename_column_ was removed from datasets; use the returning variant
# so the label column gets the name Trainer expects.
tokenized_datasets = tokenized_datasets.rename_column("toxic", "label")
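For intuition about what map just added, a hypothetical cell showing the tokenizer output for a single message (the example string is made up):
In [ ]:
# The tokenizer returns input_ids, token_type_ids, and attention_mask,
# each padded to the model's max length.
enc = tokenizer("lol that was great", padding="max_length", truncation=True)
print(list(enc.keys()), len(enc["input_ids"]))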
In [ ]:
# Hold out 20% of the subsample for evaluation.
splitset = tokenized_datasets.train_test_split(test_size=0.2)
splitset
In [ ]:
# "test_trainer" is the output directory where checkpoints will be written.
training_args = TrainingArguments("test_trainer")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=splitset["train"],
    eval_dataset=splitset["test"],
)
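As constructed, this Trainer reports only loss during evaluation. A simple accuracy metric can be attached through the standard compute_metrics hook; the following cell and its re-created Trainer are a sketch, not part of the original notebook:
In [ ]:
# Optional: rebuild the Trainer with a minimal accuracy metric.
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=splitset["train"],
    eval_dataset=splitset["test"],
    compute_metrics=compute_metrics,
)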
In [ ]:
# resume_from_checkpoint=True raises an error when no checkpoint exists yet,
# so start fresh; pass resume_from_checkpoint=True on later runs to continue
# from the latest checkpoint in "test_trainer".
trainer.train()
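Once training finishes, a hypothetical follow-up cell: evaluate on the held-out split and score one new message (the example string is invented; the 0/1 labels follow the ClassLabel defined above):
In [ ]:
import torch

# Evaluate on the held-out 20% split.
print(trainer.evaluate())

# Score a single new chat message with the fine-tuned model.
inputs = tokenizer("you should just quit streaming", return_tensors="pt", truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits
print("predicted label:", logits.argmax(dim=-1).item())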