Mirror of https://github.com/holodata/sensai-dataset.git (synced 2025-03-15 12:00:32 +09:00)

feat: add proper datagen

parent fe6fed0759 · commit dfe4980698
@@ -1,5 +1,11 @@
# Contribution Guide

## Setup

```bash
poetry install
```

## Generate dataset

```bash
Makefile (4 changed lines)

@@ -1,7 +1,7 @@
all: build upload

build:
	python3 -m sensai_gen.cli
	python3 -m sensai_dataset.gen

upload:
	kaggle datasets version -d -m "New version" --path $$DATASET_DIR
	kaggle datasets version -m "New version" --path $$DATASET_DIR
README.md (62 changed lines)

@@ -1,8 +1,10 @@
# ❤️🩹 Sensai: Toxic Chat Dataset

Sensai is a dataset consists of live chats from all across Virtual YouTubers' live streams, ready for training toxic chat classification models.
Sensai is a toxic chat dataset consisting of live chats from Virtual YouTubers' live streams.

Download the dataset from [Kaggle Datasets](https://www.kaggle.com/uetchy/sensai) and join `#livechat-dataset` channel on [holodata Discord](https://holodata.org/discord) for discussions.
Download the dataset from [Huggingface Hub](https://huggingface.co/datasets/holodata/sensai) or alternatively from [Kaggle Datasets](https://www.kaggle.com/uetchy/sensai).

Join the `#livechat-dataset` channel on [holodata Discord](https://holodata.org/discord) for discussions.

## Provenance

@@ -23,7 +25,7 @@ See [public notebooks](https://www.kaggle.com/uetchy/sensai/code) for ideas.

| filename                  | summary                                                         | size     |
| ------------------------- | --------------------------------------------------------------- | -------- |
| `chats_flagged_%Y-%m.csv` | Chats flagged as either deleted or banned by mods (3,100,000+)  | ~ 400 MB |
| `chats_nonflag_%Y-%m.csv` | Non-flagged chats (3,000,000+)                                  | ~ 300 MB |
| `chats_nonflag_%Y-%m.csv` | Non-flagged chats (3,100,000+)                                  | ~ 300 MB |

To make the dataset balanced, the number of `chats_nonflag` rows is adjusted (randomly sampled) to match `chats_flagged` (see the sketch below).
Ban and deletion correspond to `markChatItemsByAuthorAsDeletedAction` and `markChatItemAsDeletedAction`, respectively.
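Concretely, the balancing described above boils down to downsampling the non-flagged rows to the number of flagged rows. A minimal sketch with a toy DataFrame; the `label` column and the sampling step mirror the generator code added in this commit:

```python
import pandas as pd

# Toy stand-in for a month of chats, with a `label` column as produced by
# the generator (toxic / spam / safe).
chats = pd.DataFrame({
    "body": ["hello", "buy followers", "nice stream", "spam spam"],
    "label": ["safe", "spam", "safe", "toxic"],
})

is_flagged = chats["label"] != "safe"
flagged = chats[is_flagged]

# Randomly sample as many non-flagged chats as there are flagged ones,
# so the resulting dataset is balanced.
nonflag = chats[~is_flagged].sample(len(flagged), random_state=0)

balanced = pd.concat([flagged, nonflag], ignore_index=True)
print(balanced["label"].value_counts())
```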
@@ -35,38 +37,42 @@ Ban and deletion are equivalent to `markChatItemsByAuthorAsDeletedAction` and `m

| column          | type   | description                  |
| --------------- | ------ | ---------------------------- |
| body            | string | chat message                 |
| membership      | string | membership status            |
| authorChannelId | string | anonymized author channel id |
| channelId       | string | source channel id            |
| label           | enum   | toxic, spam, safe            |

#### Membership status
## Usage

| value             | duration                  |
| ----------------- | ------------------------- |
| unknown           | Indistinguishable         |
| non-member        | 0                         |
| less than 1 month | < 1 month                 |
| 1 month           | >= 1 month, < 2 months    |
| 2 months          | >= 2 months, < 6 months   |
| 6 months          | >= 6 months, < 12 months  |
| 1 year            | >= 12 months, < 24 months |
| 2 years           | >= 24 months              |

#### Pandas usage

Set `keep_default_na` to `False` and `na_values` to `''` in `read_csv`. Otherwise, chat messages like `NA` would incorrectly be treated as NaN values.
https://huggingface.co/docs/datasets/loading_datasets.html

```python
import pandas as pd
from glob import iglob
# $ pip3 install datasets
from datasets import load_dataset, Features, ClassLabel, Value
# transformers objects used by the training example below (same imports as in notebooks/playground.ipynb)
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

flagged = pd.concat([
    pd.read_csv(f,
                na_values='',
                keep_default_na=False)
    for f in iglob('../input/sensai/chats_flagged_*.csv')
],
                    ignore_index=True)
dataset = load_dataset("holodata/sensai",
                       features=Features(
                           {
                               "body": Value("string"),
                               "toxic": ClassLabel(num_classes=2, names=['0', '1'])
                           }
                       ))

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["body"], padding="max_length", truncation=True)

tokenized_datasets = dataset['train'].shuffle().select(range(50000)).map(tokenize_function, batched=True)
tokenized_datasets.rename_column_("toxic", "label")
splitset = tokenized_datasets.train_test_split(0.2)
training_args = TrainingArguments("test_trainer")

trainer = Trainer(
    model=model, args=training_args, train_dataset=splitset['train'], eval_dataset=splitset['test']
)

trainer.train()
```

## Consideration
notebooks/playground.ipynb (new file, vendored, 354 lines)

@@ -0,0 +1,354 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "source": [
        "from datasets import load_dataset, Features\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments\n",
        "import os\n",
        "from os.path import join\n",
        "import pandas as pd\n",
        "from datasets import ClassLabel, Value\n",
        "\n",
        "# https://huggingface.co/docs/datasets/loading_datasets.html\n",
        "\n",
        "DATASET_DIR = os.environ['DATASET_DIR']"
      ],
      "outputs": [],
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "source": [
        "dataset = load_dataset(\"holodata/sensai\", features=Features(\n",
        "    {\n",
        "        \"body\": Value(\"string\"),\n",
        "        \"toxic\": ClassLabel(num_classes=2, names=['0', '1'])\n",
        "    }\n",
        "    ))\n",
        "dataset = dataset['train']"
      ],
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "704789c10f1e44ddbb83262b8a826eec"
            },
            "text/plain": [
              " 0%| | 0/1 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Using custom data configuration sensai-4d9ed81389161083\n",
            "Reusing dataset parquet (/home/uetchy/.cache/huggingface/datasets/parquet/sensai-4d9ed81389161083/0.0.0/9296ce43568b20d72ff8ff8ecbc821a16b68e9b8b7058805ef11f06e035f911a)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8acdde0c4caa4d4698f97ef29993195b"
            },
            "text/plain": [
              " 0%| | 0/1 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "source": [
        "dataset.features"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'body': Value(dtype='string', id=None),\n",
              " 'toxic': ClassLabel(num_classes=2, names=['0', '1'], names_file=None, id=None)}"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "source": [
        "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
            "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "source": [
        "samples = dataset.shuffle().select(range(50000))\n",
        "\n",
        "def tokenize_function(examples):\n",
        "    return tokenizer(examples[\"body\"], padding=\"max_length\", truncation=True)\n",
        "\n",
        "tokenized_datasets = samples.map(tokenize_function, batched=True)\n",
        "tokenized_datasets.rename_column_(\"toxic\", \"label\")"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Loading cached shuffled indices for dataset at /home/uetchy/.cache/huggingface/datasets/parquet/sensai-4d9ed81389161083/0.0.0/9296ce43568b20d72ff8ff8ecbc821a16b68e9b8b7058805ef11f06e035f911a/cache-24e3dd769ef2f1b7.arrow\n",
            "Loading cached processed dataset at /home/uetchy/.cache/huggingface/datasets/parquet/sensai-4d9ed81389161083/0.0.0/9296ce43568b20d72ff8ff8ecbc821a16b68e9b8b7058805ef11f06e035f911a/cache-8395f066c72e57d7.arrow\n",
            "/tmp/ipykernel_4082765/2982913603.py:7: FutureWarning: rename_column_ is deprecated and will be removed in the next major version of datasets. Use Dataset.rename_column instead.\n",
            "  tokenized_datasets.rename_column_(\"toxic\", \"label\")\n"
          ]
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "source": [
        "splitset = tokenized_datasets.train_test_split(0.2)\n",
        "splitset"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['attention_mask', 'body', 'input_ids', 'token_type_ids', 'label'],\n",
              "        num_rows: 40000\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['attention_mask', 'body', 'input_ids', 'token_type_ids', 'label'],\n",
              "        num_rows: 10000\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "source": [
        "training_args = TrainingArguments(\"test_trainer\")\n",
        "trainer = Trainer(\n",
        "    model=model, args=training_args, train_dataset=splitset['train'], eval_dataset=splitset['test']\n",
        ")"
      ],
      "outputs": [],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "source": [
        "trainer.train(resume_from_checkpoint=True)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Loading model from test_trainer/checkpoint-12500).\n",
            "The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: body.\n",
            "***** Running training *****\n",
            "  Num examples = 40000\n",
            "  Num Epochs = 3\n",
            "  Instantaneous batch size per device = 8\n",
            "  Total train batch size (w. parallel, distributed & accumulation) = 8\n",
            "  Gradient Accumulation steps = 1\n",
            "  Total optimization steps = 15000\n",
            "  Continuing training from checkpoint, will skip to saved global_step\n",
            "  Continuing training from epoch 2\n",
            "  Continuing training from global step 12500\n",
            "  Will skip the first 2 epochs then the first 2500 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d749dea90f914485916ce82c65e73fe6"
            },
            "text/plain": [
              " 0%| | 0/2500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Didn't find an RNG file, if you are resuming a training that was launched in a distributed fashion, reproducibility is not guaranteed.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='15000' max='15000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [15000/15000 31:45, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>13000</td>\n",
              "      <td>0.687500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13500</td>\n",
              "      <td>0.686300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14000</td>\n",
              "      <td>0.637900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14500</td>\n",
              "      <td>0.643200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15000</td>\n",
              "      <td>0.627700</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Saving model checkpoint to test_trainer/checkpoint-13000\n",
            "Configuration saved in test_trainer/checkpoint-13000/config.json\n",
            "Model weights saved in test_trainer/checkpoint-13000/pytorch_model.bin\n",
            "Saving model checkpoint to test_trainer/checkpoint-13500\n",
            "Configuration saved in test_trainer/checkpoint-13500/config.json\n",
            "Model weights saved in test_trainer/checkpoint-13500/pytorch_model.bin\n",
            "Saving model checkpoint to test_trainer/checkpoint-14000\n",
            "Configuration saved in test_trainer/checkpoint-14000/config.json\n",
            "Model weights saved in test_trainer/checkpoint-14000/pytorch_model.bin\n",
            "Saving model checkpoint to test_trainer/checkpoint-14500\n",
            "Configuration saved in test_trainer/checkpoint-14500/config.json\n",
            "Model weights saved in test_trainer/checkpoint-14500/pytorch_model.bin\n",
            "Saving model checkpoint to test_trainer/checkpoint-15000\n",
            "Configuration saved in test_trainer/checkpoint-15000/config.json\n",
            "Model weights saved in test_trainer/checkpoint-15000/pytorch_model.bin\n",
            "\n",
            "\n",
            "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
            "\n",
            "\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=15000, training_loss=0.10941998901367188, metrics={'train_runtime': 1918.0916, 'train_samples_per_second': 62.562, 'train_steps_per_second': 7.82, 'total_flos': 3.24994775580672e+16, 'train_loss': 0.10941998901367188, 'epoch': 3.0})"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [],
      "outputs": [],
      "metadata": {}
    }
  ],
  "metadata": {
    "orig_nbformat": 4,
    "language_info": {
      "name": "python",
      "version": "3.8.6",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3.8.6 64-bit"
    },
    "interpreter": {
      "hash": "c8daecebaf2d81430b8373e4b4af380b12df116248cd1bbadd3fc947f45a1f88"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
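The stderr output in the notebook above flags `rename_column_` as deprecated ("Use Dataset.rename_column instead"). A minimal sketch of the non-deprecated call, using a tiny in-memory stand-in for the sensai split just to show the call shape; unlike `rename_column_`, it returns a new dataset rather than renaming in place:

```python
from datasets import Dataset

# Tiny stand-in for the loaded split, with the same column names.
ds = Dataset.from_dict({"body": ["hello", "spam spam"], "toxic": [0, 1]})

# Non-deprecated replacement for `rename_column_`: returns a new Dataset,
# so reassign the result instead of relying on in-place mutation.
ds = ds.rename_column("toxic", "label")
print(ds.column_names)  # ['body', 'label']
```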
poetry.lock (generated, 1577 changed lines): file diff suppressed because it is too large.
pyproject.toml

@@ -1,22 +1,23 @@
[tool.poetry]
name = "sensai"
name = "sensai-dataset"
version = "1.0.0"
description = "Toxic Live Chat Dataset"
authors = ["Yasuaki Uechi <y@uechi.io>"]

[tool.poetry.dependencies]
python = "^3.9"
kaggle = "^1.5.12"
numpy = "^1.21.0"
pandas = "^1.2.3"
pymongo = "^3.11.3"
python-dateutil = "^2.8.1"
altair = "^4.1.0"
matplotlib = "^3.4.2"
streamlit = "^0.87.0"
plotly = "^5.0.0"
python = "^3.8"
datasets = {git = "https://github.com/huggingface/datasets.git", rev = "master"}
#datasets = "^1.11.0"
pyarrow = "^5.0.0"

[tool.poetry.dev-dependencies]
pymongo = "^3.11.3"
transformers = "^4.10.0"
torch = "^1.9.0"
tokenizers = "^0.10.3"
kaggle = "^1.5.12"
pandas = "^1.2.3"
python-dateutil = "^2.8.1"
ipykernel = "^6.2.0"
yapf = "^0.31.0"
sensai_dataset/dataset.py (new file, 13 lines)

@@ -0,0 +1,13 @@
from datasets.features import ClassLabel, Features, Value
from datasets.load import load_dataset


def load_sensai_dataset():
    dataset = load_dataset(
        "holodata/sensai",
        features=Features({
            "body": Value("string"),
            "toxic": ClassLabel(num_classes=2, names=['0', '1'])
        }))
    dataset = dataset['train']
    return dataset
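A quick usage sketch for the helper above. It assumes `datasets` is installed and that the `holodata/sensai` dataset is reachable on the Huggingface Hub, since `load_dataset` downloads it on first call:

```python
# Usage sketch for sensai_dataset/dataset.py (needs network access on first run).
from sensai_dataset.dataset import load_sensai_dataset

dataset = load_sensai_dataset()
print(dataset.features)      # {'body': Value(...), 'toxic': ClassLabel(...)}
print(dataset[0]["body"])    # first chat message in the train split
```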
sensai_dataset/generator/__init__.py (new file, 16 lines)

@@ -0,0 +1,16 @@
import argparse

from sensai_dataset.generator.commands import generate_dataset
from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='dataset generator')
    parser.add_argument('-m', '--matcher', type=str, default='chats_*.csv')
    args = parser.parse_args()

    print('target: ' + DATASET_DIR)
    print('source: ' + DATASET_SOURCE_DIR)

    generate_dataset(source_dir=DATASET_SOURCE_DIR,
                     target_dir=DATASET_DIR,
                     matcher=args.matcher)
sensai_dataset/generator/commands.py (new file, 91 lines)

@@ -0,0 +1,91 @@
import gc
from glob import iglob
from os.path import basename, join, splitext

import pandas as pd


def generate_dataset(source_dir, target_dir, matcher):
    print('[generate_sensai_dataset]')

    delet_path = join(source_dir, 'deletion_events.csv')
    del_events = pd.read_csv(delet_path, usecols=['id', 'retracted'])
    del_events = del_events.query('retracted == 0').copy()
    del_events.drop(columns=['retracted'], inplace=True)
    del_events['label'] = 'toxic'

    ban_path = join(source_dir, 'ban_events.csv')
    ban_events = pd.read_csv(ban_path, usecols=['authorChannelId', 'videoId'])
    ban_events['label'] = 'spam'

    for f in sorted(iglob(join(source_dir, matcher))):
        period_string = splitext(basename(f))[0].split('_')[1]
        print('>>> Period:', period_string)

        # load chat
        print('>>> Loading chats')
        chat_path = join(source_dir, 'chats_' + period_string + '.csv')

        chats = pd.read_csv(chat_path,
                            na_values='',
                            keep_default_na=False,
                            usecols=[
                                'authorChannelId',
                                'videoId',
                                'id',
                                'body',
                            ])

        # remove NA
        chats = chats[chats['body'].notna()]

        # apply mods
        print('>>> Merging bans')
        chats = pd.merge(chats,
                         ban_events,
                         on=['authorChannelId', 'videoId'],
                         how='left')

        # apply mods
        print('>>> Merging deletion')
        chats = pd.merge(chats, del_events, on='id', how='left')

        # fill NA label
        chats['label'].fillna('safe', inplace=True)

        isFlagged = chats['label'] != 'safe'
        flagged = chats[isFlagged].copy()

        # to make balanced dataset
        nbFlagged = flagged.shape[0]
        if nbFlagged == 0:
            continue

        print('>>> Sampling nonflagged chats')
        print('nbFlagged', nbFlagged)
        nonflag = chats[~isFlagged].sample(nbFlagged)

        print('>>> Writing dataset')

        # NOTE: do not use categorical dtypes with to_parquet; otherwise the
        # resulting files fail to load with huggingface's Dataset
        columns_to_delete = [
            'authorChannelId',
            'videoId',
            'id',
        ]

        flagged.drop(columns=columns_to_delete, inplace=True)
        flagged.to_parquet(join(target_dir,
                                f'chats_flagged_{period_string}.parquet'),
                           index=False)

        nonflag.drop(columns=columns_to_delete, inplace=True)
        nonflag.to_parquet(join(target_dir,
                                f'chats_nonflag_{period_string}.parquet'),
                           index=False)

        # free up memory
        del nonflag
        del flagged
        del chats
        gc.collect()
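Besides `python3 -m ...` (as wired up in the Makefile and `__init__.py` above), the generator can be called directly. A minimal sketch; the paths below are placeholders, since the real values come from the DATASET_SOURCE_DIR and DATASET_DIR environment variables handled by the constants module:

```python
# Sketch of invoking the generator function added in this commit directly.
from sensai_dataset.generator.commands import generate_dataset

generate_dataset(
    source_dir="/data/sensai-source",  # placeholder: holds chats_*.csv, ban_events.csv, deletion_events.csv
    target_dir="/data/sensai",         # placeholder: receives chats_flagged_*.parquet / chats_nonflag_*.parquet
    matcher="chats_2021-08.csv",       # process a single month instead of the default 'chats_*.csv'
)
```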
@@ -4,4 +4,4 @@ DATASET_DIR = os.environ['DATASET_DIR']
DATASET_SOURCE_DIR = os.environ['DATASET_SOURCE_DIR']

os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(DATASET_SOURCE_DIR, exist_ok=True)
os.makedirs(DATASET_SOURCE_DIR, exist_ok=True)
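Because the constants module reads both directories from the environment at import time (and creates them with `os.makedirs(..., exist_ok=True)`), the environment variables have to be set before the generator package is imported. A minimal sketch with placeholder paths:

```python
import os

# Placeholder paths; set these before importing sensai_dataset.generator.*,
# since the constants module reads the environment when it is imported.
os.environ["DATASET_SOURCE_DIR"] = "/data/sensai-source"
os.environ["DATASET_DIR"] = "/data/sensai"

from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR

print(DATASET_SOURCE_DIR, DATASET_DIR)
```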
@@ -1,123 +0,0 @@
import gc
from glob import iglob
import argparse
import shutil
from os.path import basename, join, splitext

import numpy as np
import pandas as pd

from sensai_gen.constants import DATASET_DIR, DATASET_SOURCE_DIR


def load_channels(**kwargs):
    dtype_dict = {
        'channelId': 'category',
        'name': 'category',
        'englishName': 'category',
        'affiliation': 'category',
        'group': 'category',
        'subscriptionCount': 'int32',
        'videoCount': 'int32',
        'photo': 'category'
    }
    channels = pd.read_csv(join(DATASET_SOURCE_DIR, 'channels.csv'),
                           dtype=dtype_dict,
                           **kwargs)
    return channels


def generate_dataset(matcher):
    print('[generate_sensai_dataset]')

    delet_path = join(DATASET_SOURCE_DIR, 'deletion_events.csv')
    del_events = pd.read_csv(delet_path, usecols=['id', 'retracted'])
    del_events = del_events.query('retracted == 0').copy()
    del_events.drop(columns=['retracted'], inplace=True)
    del_events['deleted'] = True

    ban_path = join(DATASET_SOURCE_DIR, 'ban_events.csv')
    ban_events = pd.read_csv(ban_path, usecols=['authorChannelId', 'videoId'])
    ban_events['banned'] = True

    for f in sorted(iglob(join(DATASET_SOURCE_DIR, matcher))):
        period_string = splitext(basename(f))[0].split('_')[1]
        print('>>> Period:', period_string)

        columns_to_use = [
            'body',
            'authorChannelId',
            'channelId',
            'membership',
            'id',
            'videoId',
        ]
        columns_to_delete = [
            'id',
            'videoId',
            'deleted',
            'banned',
        ]

        # load chat
        print('>>> Loading chats')
        chat_path = join(DATASET_SOURCE_DIR, 'chats_' + period_string + '.csv')
        chat_dtype = {
            'authorChannelId': 'category',
            'membership': 'category',
            'videoId': 'category',
            'channelId': 'category'
        }
        chats = pd.read_csv(chat_path, dtype=chat_dtype, usecols=columns_to_use)

        # apply mods
        print('>>> Merging deletion')
        chats = pd.merge(chats, del_events, on='id', how='left')
        chats['deleted'].fillna(False, inplace=True)

        # apply mods
        print('>>> Merging bans')
        chats = pd.merge(chats,
                         ban_events,
                         on=['authorChannelId', 'videoId'],
                         how='left')
        chats['banned'].fillna(False, inplace=True)

        flagged = chats[(chats['deleted'] | chats['banned'])].copy()

        # to make balanced dataset
        nbFlagged = flagged.shape[0]
        if nbFlagged == 0:
            continue

        print('>>> Sampling nonflagged chats')
        print('nbFlagged', nbFlagged)
        nonflag = chats[~(chats['deleted'] | chats['banned'])].sample(nbFlagged)

        print('>>> Writing dataset')

        flagged.drop(columns=columns_to_delete, inplace=True)
        flagged.to_csv(join(DATASET_DIR, f'chats_flagged_{period_string}.csv'),
                       index=False)
        nonflag.drop(columns=columns_to_delete, inplace=True)
        nonflag.to_csv(join(DATASET_DIR, f'chats_nonflag_{period_string}.csv'),
                       index=False)

        # free up memory
        del nonflag
        del flagged
        del chats
        gc.collect()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='dataset generator')
    parser.add_argument('-m', '--matcher', type=str, default='chats_*.csv')
    args = parser.parse_args()

    print('target: ' + DATASET_DIR)
    print('source: ' + DATASET_SOURCE_DIR)

    shutil.copy(join(DATASET_SOURCE_DIR, 'channels.csv'), DATASET_DIR)

    generate_dataset(matcher=args.matcher)