feat: add proper datagen

This commit is contained in:
2021-09-09 01:14:21 +09:00
parent fe6fed0759
commit dfe4980698
13 changed files with 1145 additions and 1126 deletions

View File

13
sensai_dataset/dataset.py Normal file
View File

@@ -0,0 +1,13 @@
from datasets.features import ClassLabel, Features, Value
from datasets.load import load_dataset
def load_sensai_dataset():
    """Return the 'train' split of the holodata/sensai dataset.

    The split is loaded with an explicit schema: ``body`` is a plain
    string and ``toxic`` is a binary ClassLabel named '0'/'1'.
    """
    schema = Features({
        "body": Value("string"),
        "toxic": ClassLabel(num_classes=2, names=['0', '1']),
    })
    return load_dataset("holodata/sensai", features=schema)["train"]

View File

@@ -0,0 +1,16 @@
import argparse
from sensai_dataset.generator.commands import generate_dataset
from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR
if __name__ == '__main__':
    # CLI entry point: regenerate the dataset from the configured dirs.
    arg_parser = argparse.ArgumentParser(description='dataset generator')
    arg_parser.add_argument('-m', '--matcher', type=str, default='chats_*.csv')
    cli_args = arg_parser.parse_args()

    print('target: ' + DATASET_DIR)
    print('source: ' + DATASET_SOURCE_DIR)

    generate_dataset(
        source_dir=DATASET_SOURCE_DIR,
        target_dir=DATASET_DIR,
        matcher=cli_args.matcher,
    )

View File

@@ -0,0 +1,91 @@
import gc
from glob import iglob
from os.path import basename, join, splitext
import pandas as pd
def generate_dataset(source_dir, target_dir, matcher):
    """Build a balanced flagged/nonflagged chat dataset per period.

    For every chat CSV in ``source_dir`` matching ``matcher``, label chats
    'spam' (author/video appears in ban_events.csv), 'toxic' (message id
    appears in deletion_events.csv with retracted == 0) or 'safe', then
    write one parquet pair per period to ``target_dir``:
    ``chats_flagged_<period>.parquet`` and an equally-sized random sample
    ``chats_nonflag_<period>.parquet``.

    Parameters
    ----------
    source_dir : str
        Directory holding ``chats_*.csv``, ``deletion_events.csv`` and
        ``ban_events.csv``.
    target_dir : str
        Output directory for the parquet files.
    matcher : str
        Glob pattern (relative to ``source_dir``) selecting chat CSVs,
        e.g. ``'chats_*.csv'``.
    """
    print('[generate_sensai_dataset]')

    # Deletion events: keep rows with retracted == 0 and label them 'toxic'.
    deletion_path = join(source_dir, 'deletion_events.csv')
    del_events = pd.read_csv(deletion_path, usecols=['id', 'retracted'])
    del_events = del_events.query('retracted == 0').copy()
    del_events.drop(columns=['retracted'], inplace=True)
    del_events['label'] = 'toxic'

    # Ban events: every (author, video) pair present is labeled 'spam'.
    ban_path = join(source_dir, 'ban_events.csv')
    ban_events = pd.read_csv(ban_path, usecols=['authorChannelId', 'videoId'])
    ban_events['label'] = 'spam'

    for f in sorted(iglob(join(source_dir, matcher))):
        # e.g. 'chats_2021-01.csv' -> '2021-01'
        period_string = splitext(basename(f))[0].split('_')[1]
        print('>>> Period:', period_string)

        # load chat
        print('>>> Loading chats')
        chat_path = join(source_dir, 'chats_' + period_string + '.csv')
        chats = pd.read_csv(chat_path,
                            na_values='',
                            keep_default_na=False,
                            usecols=[
                                'authorChannelId',
                                'videoId',
                                'id',
                                'body',
                            ])

        # remove rows with an empty body
        chats = chats[chats['body'].notna()]

        # attach 'spam' labels from ban events
        print('>>> Merging bans')
        chats = pd.merge(chats,
                         ban_events,
                         on=['authorChannelId', 'videoId'],
                         how='left')

        # attach 'toxic' labels from deletion events.
        # BUG FIX: both merge inputs carry a 'label' column, so without
        # explicit suffixes pandas produced 'label_x'/'label_y' and the
        # later chats['label'] access raised KeyError.  Merge with named
        # suffixes and coalesce: deletion ('toxic') takes precedence over
        # ban ('spam'), anything unlabeled becomes 'safe'.
        print('>>> Merging deletion')
        chats = pd.merge(chats, del_events, on='id', how='left',
                         suffixes=('_ban', '_del'))
        # Assign back instead of Series.fillna(inplace=True): the inplace
        # form on a column selection is deprecated and unreliable under
        # pandas copy-on-write.
        chats['label'] = chats['label_del'].fillna(chats['label_ban'])
        chats['label'] = chats['label'].fillna('safe')
        chats.drop(columns=['label_ban', 'label_del'], inplace=True)

        is_flagged = chats['label'] != 'safe'
        flagged = chats[is_flagged].copy()

        # to make a balanced dataset, sample as many nonflagged chats as
        # there are flagged ones
        nb_flagged = flagged.shape[0]
        if nb_flagged == 0:
            continue
        print('>>> Sampling nonflagged chats')
        print('nbFlagged', nb_flagged)
        nonflag_pool = chats[~is_flagged]
        # cap the sample size so a period with fewer safe chats than
        # flagged ones cannot raise ValueError
        nonflag = nonflag_pool.sample(min(nb_flagged, len(nonflag_pool)))

        print('>>> Writing dataset')
        # NOTE: do not use categorical dtype with to_parquet; otherwise
        # loading the files with huggingface's Dataset fails
        columns_to_delete = [
            'authorChannelId',
            'videoId',
            'id',
        ]
        flagged.drop(columns=columns_to_delete, inplace=True)
        flagged.to_parquet(join(target_dir,
                                f'chats_flagged_{period_string}.parquet'),
                           index=False)
        nonflag.drop(columns=columns_to_delete, inplace=True)
        nonflag.to_parquet(join(target_dir,
                                f'chats_nonflag_{period_string}.parquet'),
                           index=False)

        # free up memory before the next period's CSV is loaded
        del nonflag
        del flagged
        del chats
        gc.collect()

View File

@@ -0,0 +1,7 @@
import os

# Required locations come from the environment; a missing variable raises
# KeyError at import time, which is the intended fail-fast behavior.
DATASET_DIR = os.environ['DATASET_DIR']
DATASET_SOURCE_DIR = os.environ['DATASET_SOURCE_DIR']

# Make sure both directories exist before any generator code touches them.
for _dir in (DATASET_DIR, DATASET_SOURCE_DIR):
    os.makedirs(_dir, exist_ok=True)