Mirror of https://github.com/holodata/sensai-dataset.git, synced 2025-08-20 11:18:12 +09:00
feat: add proper datagen
0  sensai_dataset/__init__.py  Normal file
13  sensai_dataset/dataset.py  Normal file
@@ -0,0 +1,13 @@
from datasets.features import ClassLabel, Features, Value
from datasets.load import load_dataset


def load_sensai_dataset():
    dataset = load_dataset(
        "holodata/sensai",
        features=Features({
            "body": Value("string"),
            "toxic": ClassLabel(num_classes=2, names=['0', '1'])
        }))
    dataset = dataset['train']
    return dataset
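
For context, a minimal usage sketch of the loader above (assuming the `datasets` library is installed and `holodata/sensai` is reachable on the Hugging Face Hub):

from sensai_dataset.dataset import load_sensai_dataset

# downloads/caches the train split with the schema declared above
dataset = load_sensai_dataset()
example = dataset[0]
print(example['body'])   # chat message text
print(example['toxic'])  # ClassLabel index: 0 or 1
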
16  sensai_dataset/generator/__init__.py  Normal file
@@ -0,0 +1,16 @@
import argparse

from sensai_dataset.generator.commands import generate_dataset
from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='dataset generator')
    parser.add_argument('-m', '--matcher', type=str, default='chats_*.csv')
    args = parser.parse_args()

    print('target: ' + DATASET_DIR)
    print('source: ' + DATASET_SOURCE_DIR)

    generate_dataset(source_dir=DATASET_SOURCE_DIR,
                     target_dir=DATASET_DIR,
                     matcher=args.matcher)
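
A usage note: because the `__name__ == '__main__'` guard lives in `__init__.py` rather than a `__main__.py`, `python -m sensai_dataset.generator` will not trigger it; the file has to be run directly, e.g. `PYTHONPATH=. python sensai_dataset/generator/__init__.py --matcher 'chats_*.csv'` from the repository root (a hypothetical invocation; the environment variables read by constants.py must be set first, as sketched after constants.py below).
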
91  sensai_dataset/generator/commands.py  Normal file
@@ -0,0 +1,91 @@
import gc
from glob import iglob
from os.path import basename, join, splitext

import pandas as pd


def generate_dataset(source_dir, target_dir, matcher):
    print('[generate_sensai_dataset]')

    deletion_path = join(source_dir, 'deletion_events.csv')
    del_events = pd.read_csv(deletion_path, usecols=['id', 'retracted'])
    del_events = del_events.query('retracted == 0').copy()
    del_events.drop(columns=['retracted'], inplace=True)
    del_events['label'] = 'toxic'

    ban_path = join(source_dir, 'ban_events.csv')
    ban_events = pd.read_csv(ban_path, usecols=['authorChannelId', 'videoId'])
    ban_events['label'] = 'spam'

    for f in sorted(iglob(join(source_dir, matcher))):
        period_string = splitext(basename(f))[0].split('_')[1]
        print('>>> Period:', period_string)

        # load chats for this period
        print('>>> Loading chats')
        chat_path = join(source_dir, 'chats_' + period_string + '.csv')

        chats = pd.read_csv(chat_path,
                            na_values='',
                            keep_default_na=False,
                            usecols=[
                                'authorChannelId',
                                'videoId',
                                'id',
                                'body',
                            ])

        # remove rows with missing bodies
        chats = chats[chats['body'].notna()]

        # label chats from banned users as spam
        print('>>> Merging bans')
        chats = pd.merge(chats,
                         ban_events,
                         on=['authorChannelId', 'videoId'],
                         how='left')

        # label deleted chats as toxic; both event frames carry a 'label'
        # column, so keep the ban label unsuffixed and coalesce the two
        print('>>> Merging deletion')
        chats = pd.merge(chats, del_events, on='id', how='left',
                         suffixes=('', '_del'))
        chats['label'] = chats['label'].fillna(chats.pop('label_del'))

        # everything still unlabeled is safe
        chats['label'] = chats['label'].fillna('safe')

        isFlagged = chats['label'] != 'safe'
        flagged = chats[isFlagged].copy()

        # to make a balanced dataset
        nbFlagged = flagged.shape[0]
        if nbFlagged == 0:
            continue

        print('>>> Sampling nonflagged chats')
        print('nbFlagged', nbFlagged)
        nonflag = chats[~isFlagged].sample(nbFlagged)

        print('>>> Writing dataset')

        # NOTE: do not use categorical dtype with to_parquet; otherwise
        # the files fail to load with huggingface's Dataset
        columns_to_delete = [
            'authorChannelId',
            'videoId',
            'id',
        ]

        flagged.drop(columns=columns_to_delete, inplace=True)
        flagged.to_parquet(join(target_dir,
                                f'chats_flagged_{period_string}.parquet'),
                           index=False)

        nonflag.drop(columns=columns_to_delete, inplace=True)
        nonflag.to_parquet(join(target_dir,
                                f'chats_nonflag_{period_string}.parquet'),
                           index=False)

        # free up memory
        del nonflag
        del flagged
        del chats
        gc.collect()
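
A quick verification sketch for one output pair (the period string and output directory here are hypothetical; adjust them to your DATASET_DIR and actual source files):

import pandas as pd

flagged = pd.read_parquet('out/chats_flagged_2021-03.parquet')
nonflag = pd.read_parquet('out/chats_nonflag_2021-03.parquet')

# the generator samples non-flagged chats 1:1 against flagged ones,
# so both halves of a period should have the same number of rows
assert len(flagged) == len(nonflag)
print(flagged['label'].value_counts())  # 'toxic' and 'spam' counts
print(nonflag['label'].unique())        # ['safe']
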
7  sensai_dataset/generator/constants.py  Normal file
@@ -0,0 +1,7 @@
import os

DATASET_DIR = os.environ['DATASET_DIR']
DATASET_SOURCE_DIR = os.environ['DATASET_SOURCE_DIR']

os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(DATASET_SOURCE_DIR, exist_ok=True)
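
Since constants.py reads both directories from the environment at import time (and importing anything under sensai_dataset.generator executes its `__init__.py`, which imports constants), the variables must be set before the first import. A minimal programmatic sketch, with hypothetical paths:

import os

# hypothetical locations; point these at your own directories
os.environ['DATASET_SOURCE_DIR'] = '/data/sensai/source'
os.environ['DATASET_DIR'] = '/data/sensai/out'

# import only after the variables are set, since constants.py
# reads os.environ and creates the directories on import
from sensai_dataset.generator.commands import generate_dataset
from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR

generate_dataset(source_dir=DATASET_SOURCE_DIR,
                 target_dir=DATASET_DIR,
                 matcher='chats_*.csv')
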