In [1]:
import os
from os.path import join
from glob import glob
import pandas as pd
from sentence_transformers import SentenceTransformer
import tangram
DATASET_DIR = os.environ.get("DATASET_DIR", "../input/sensai")
print("DATASET_DIR", DATASET_DIR)
In [2]:
# load sentence embedding model
st = SentenceTransformer('all-MiniLM-L6-v2')
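A quick sanity check (a sketch, not part of the original pipeline): all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, which later become the numeric feature columns fed to tangram.
In [ ]:
# Illustrative only: confirm the embedding dimensionality before encoding millions of chats.
vec = st.encode(["hello world"], convert_to_numpy=True)
print(vec.shape)  # expected: (1, 384)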
In [3]:
print('Loading data...')
df = pd.concat(
    [pd.read_parquet(x) for x in glob(join(DATASET_DIR, '*.parquet'))],
    ignore_index=True)
In [8]:
print('Sampling data...')
# Chats with body size larger than 1
# df = df[df['body'].str.len() > 1]
# Merge labels
df['label'] = df['label'].apply(lambda x: {
    'deleted': 'toxic',
    'hidden': 'toxic',
    'nonflagged': 'safe'
}[x])
# Drop duplicates
df = df.drop_duplicates()
df.info()
# Balance item count for each category
g = df.groupby('label')
sample_size = min(g.size().min(), 4_000_000)
print('sample_size', sample_size)
df = g.apply(lambda x: x.sample(sample_size)).reset_index(drop=True)
df.info()
print(df.groupby('label').size())
df.sample(5)
Out[8]:
In [9]:
print('Formatting data...')
# Add counting features
# df['blen'] = df['body'].str.len()
# Concat author name and message with a separator
df['body'] = df.apply(lambda x: (x['authorName'] or '<EMPTY>') + '<SEP>' + x['body'], axis=1)
df.drop(columns=['authorName'], inplace=True)
df.info()
df.sample(5)
Out[9]:
In [13]:
print('Calculating embedding vectors...')
emb = pd.DataFrame(st.encode(df['body'].to_list(), convert_to_numpy=True))
df = pd.concat([df, emb], axis=1).drop(columns=['body'])
df.head()
Out[13]:
In [14]:
df.info()
df.to_csv('./data/input_7m_balanced_deduped_amc_MiniLM.csv', index=False)
In [ ]:
!tangram train -f ./data/input_7m_balanced_deduped_amc_MiniLM.csv -o ./models/sensai_7m_balanced_deduped_amc_MiniLM.tangram -t label
In [4]:
model = tangram.Model.from_path('./models/sensai_7m_balanced_deduped_amc_MiniLM.tangram')
def remap_label(pred) -> str:
    label = '👍'
    if pred.class_name == 'toxic':
        if pred.probability <= 0.6:
            label = '⚖️'
        else:
            label = '🚽'
    return label

def predict_batch(texts):
    preds = model.predict([{str(k): v for k, v in enumerate(x)} for x in st.encode(texts)])
    return [(remap_label(pred), pred.probability) for pred in preds]

def predict(text: str):
    return predict_batch([text])[0]
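For a single message, predict() simply wraps predict_batch(); a minimal usage sketch (the exact label and probability depend on the trained model):
In [ ]:
# Illustrative only: classify one chat message and print its emoji label and score.
label, prob = predict("hello")
print(label, prob)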
In [ ]:
chats = [
    "lol",
    "草",  # Japanese net slang for "lol"
    "かわいい!!!",  # cute!!!
    "INAAAAAAAAAA",
    "[ENG] Heh ? don't I have way less recovery items ? don't they come back on each continue ????",
    "What's with this sassy lost child?",
    "Your village looks so cute and pretty",
    "imagine the dead team member watching him dance for no reason",
    "A型だけど、気にならないどころかむしろ安心感ある",  # I'm blood type A, but far from bothering me it's actually reassuring
    "Tenchou pls even if you don't think your singing is as good, you are still one of the most electric and infectious souls I've ever seen. Ganbatte bossu!!!",
    "Nice butt",
    "My dog died let's goooooooooooo",
    "when will you graduate?",
    "lol you guys are pathetic",
    "死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね",  # die die die die die die ...
    "卒業卒業卒業卒業",  # graduate graduate graduate graduate
]
for text, pred in zip(chats, predict_batch(chats)):
    print(pred[0], ':', text)
In [5]:
deleted_by_mods = df[df['label'] != 'nonflagged']
safe = df[df['label'] == 'nonflagged']
print('# of deleted chats', len(deleted_by_mods))
print('# of safe chats', len(safe))
In [6]:
N = 50000
def adjust_label(pred):
    if pred.class_name == 'safe':
        if pred.probability <= 0.5:
            return 'toxic'
        else:
            return 'safe'
    return 'toxic'

def predict_labels(chats):
    bodies = chats.apply(lambda x: (x['authorName'] or '<EMPTY>') + '<SEP>' + x['body'], axis=1).to_list()
    preds = model.predict([{str(k): v for k, v in enumerate(x)} for x in st.encode(bodies)])
    return [(x.class_name, x.probability) for x in preds]

deleted_labels = predict_labels(deleted_by_mods.sample(N))
safe_labels = predict_labels(safe.sample(N))
In [7]:
def f_beta(prec, rec, beta):
    return (1 + beta**2) * ((prec * rec) / (beta**2 * prec + rec))
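# Quick check with illustrative numbers (not taken from the evaluation below):
# f_beta(0.8, 0.9, 2) = 5 * (0.8 * 0.9) / (4 * 0.8 + 0.9) ≈ 0.878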
fn = len([x for x in deleted_labels if x[0] == 'safe'])  # moderated chats predicted safe (false negatives)
fp = len([x for x in safe_labels if x[0] != 'safe'])  # nonflagged chats predicted toxic (false positives)
tp = N - fn
tn = N - fp
pp = tp + fp  # predicted positives
fnr = fn / N
fpr = fp / N
sensitivity = tp / N
precision = tp / pp
accuracy = (tp + tn) / (N*2)
fb = f_beta(precision, sensitivity, 2)
print('N', len(deleted_labels)+len(safe_labels))
print('FPR', fpr)
print('FNR', fnr)
print('Sensitivity', sensitivity)
print('Precision', precision)
print('Accuracy', accuracy)
print('F(2) score', fb)
In [23]:
predict_labels(pd.DataFrame.from_dict([{'authorName': 'Skamor', 'body': '味ってわかる?"あじ"って読むんやで'}]))  # "Do you know 味? It's read 'aji', you know"
Out[23]:
In [26]:
pd.DataFrame.from_dict([{'authorName': 'Hi', 'body': 'Hi'}])
Out[26]:
In [111]:
import json

def convertRawMessageToString(rawMessage):
    def handler(run):
        msgType = list(run.keys())[0]
        payload = run[msgType]
        if msgType == 'text':
            return payload
        elif msgType == 'emoji':
            if 'isCustomEmoji' in payload:
                return "\uFFFD"
            else:
                return payload['emojiId']
        else:
            raise ValueError('Invalid type: ' + msgType + ', ' + str(payload))
    return "".join([handler(run) for run in rawMessage])

with open('addChatItemAction.jsonl') as f:
    lines = list(map(json.loads, f.readlines()))
chats = pd.DataFrame.from_dict(lines)
chats = chats[['authorName', 'message']]
chats['message'] = chats['message'].apply(convertRawMessageToString)
chats = chats.rename(columns={'message': 'body'})
chats['label'] = pd.Series(
    map(lambda x: 'hold' if x[1] < 0.85 else {'safe': 'safe', 'toxic': 'annoying'}[x[0]],
        predict_labels(chats)))
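For reference, a hedged sketch of the run format handler() above expects (the exact addChatItemAction schema may differ; this example is assumed, not taken from the dataset):
In [ ]:
# Illustrative only: each raw message is a list of runs, either {'text': ...} or
# {'emoji': {...}}; custom emoji are replaced with the U+FFFD replacement character.
sample_message = [{'text': 'nice one '}, {'emoji': {'emojiId': '🎉'}}]
print(convertRawMessageToString(sample_message))  # -> 'nice one 🎉'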
In [106]:
chats[(chats['label'] == 'safe') & (chats['body'].str.len() > 3)].sample(20)
Out[106]: