sensai/examples/tangram.ipynb

34 KiB

In [1]:
import os
from os.path import join
from glob import glob
import pandas as pd
from sentence_transformers import SentenceTransformer
import tangram

# Directory holding the *.parquet chat shards; override with the DATASET_DIR env var.
DATASET_DIR = os.getenv("DATASET_DIR", "../input/sensai")
print("DATASET_DIR", DATASET_DIR)
DATASET_DIR /home/uetchy/repos/src/github.com/holodata/sensai-huggingface
In [2]:
# Sentence-embedding model: turns each chat string into a fixed-size feature vector.
st = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')
In [3]:
print('Loading data...')
# glob order is filesystem-dependent; sorting makes the concatenated row order (and therefore
# any downstream sampling) deterministic across machines.
parquet_paths = sorted(glob(join(DATASET_DIR, '*.parquet')))
if not parquet_paths:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f'No *.parquet files found in {DATASET_DIR!r}; set DATASET_DIR.')
df = pd.concat(
    [pd.read_parquet(path) for path in parquet_paths],
    ignore_index=True)
Loading data...
In [8]:
print('Sampling data...')

SEED = 42  # fixed so the balanced sample (and everything trained on it) is reproducible
MAX_SAMPLES_PER_LABEL = 4_000_000  # cap on rows kept per class
# Collapse the three moderation outcomes into a binary target.
LABEL_MAP = {
    'deleted': 'toxic',
    'hidden': 'toxic',
    'nonflagged': 'safe'
}

# Fail loudly on labels the map doesn't cover (Series.map would silently turn them into NaN).
# NOTE(review): this cell rewrites df in place, so re-running it without reloading fails here.
unexpected = set(df['label'].unique()) - set(LABEL_MAP)
assert not unexpected, f'unexpected labels: {unexpected}'
df['label'] = df['label'].map(LABEL_MAP)

# Drop exact duplicate rows (same authorName, body and merged label).
# NOTE(review): identical (authorName, body) pairs with conflicting labels are still kept.
df = df.drop_duplicates()
df.info()

# Balance classes: down-sample every label to the size of the smallest one (capped).
sample_size = min(df.groupby('label').size().min(), MAX_SAMPLES_PER_LABEL)
print('sample_size', sample_size)
df = (
    df.groupby('label')
    .sample(n=sample_size, random_state=SEED)
    .reset_index(drop=True)
)

df.info()
print(df.groupby('label').size())
df.sample(5)
Sampling data...
<class 'pandas.core.frame.DataFrame'>
Int64Index: 9121058 entries, 0 to 12942624
Data columns (total 3 columns):
 #   Column      Dtype 
---  ------      ----- 
 0   authorName  object
 1   body        object
 2   label       object
dtypes: object(3)
memory usage: 278.4+ MB
sample_size 3647717
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7295434 entries, 0 to 7295433
Data columns (total 3 columns):
 #   Column      Dtype 
---  ------      ----- 
 0   authorName  object
 1   body        object
 2   label       object
dtypes: object(3)
memory usage: 167.0+ MB
label
safe     3647717
toxic    3647717
dtype: int64
Out[8]:
authorName body label
2646589 それもえ ナイスー safe
6722778 LassieOW K U S A K E N ? ? ? K U S A K E N ? ? ? K U S... toxic
4490817 ヤルダバオトここに 猫左下 toxic
522319 IkeShamus 5 safe
6359809 susi Susanti keigo tsundere toxic
In [9]:
print('Formatting data...')

# Model input is "<author><SEP><message>"; missing (None/NaN) or empty author names become
# '<EMPTY>'. Vectorised string ops replace a row-wise DataFrame.apply over millions of rows.
assert df['body'].notna().all(), 'found null message bodies'
author = df['authorName']
author = author.where(author.notna() & author.ne(''), '<EMPTY>')
df = df.assign(body=author + '<SEP>' + df['body']).drop(columns=['authorName'])

df.info()
df.sample(5)
Formatting data...
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7295434 entries, 0 to 7295433
Data columns (total 2 columns):
 #   Column  Dtype 
---  ------  ----- 
 0   body    object
 1   label   object
dtypes: object(2)
memory usage: 111.3+ MB
Out[9]:
body label
6998597 allormist<SEP>demon slayer copied this idea toxic
2536692 Rei<SEP>ああああああああ safe
3713136 Alex<SEP>Look at the first part of the Quest. ... toxic
5614434 Nov4Zyz<SEP>かわいい toxic
517705 Aida<SEP>lmaoooo safe
In [13]:
print('Calculating embedding vectors...')

# Encode every "<author><SEP><message>" string; one row of embedding columns per chat.
vectors = st.encode(df['body'].to_list(), convert_to_numpy=True, show_progress_bar=True)
# Build the embedding frame on df's own index so the column-wise concat pairs each vector with
# its label regardless of whether df's index is a clean 0..n-1 RangeIndex.
emb = pd.DataFrame(vectors, index=df.index)
df = pd.concat([df, emb], axis=1).drop(columns=['body'])

df.head()
Calculating embedding vectors...
Out[13]:
label 0 1 2 3 4 5 6 7 8 ... 374 375 376 377 378 379 380 381 382 383
0 safe 0.009612 0.052424 0.083845 -0.031595 0.028893 0.059027 0.045335 -0.014343 0.037504 ... -0.088224 0.037281 0.020476 0.067415 0.000324 -0.049588 0.006397 0.139169 0.049218 -0.025690
1 safe 0.003328 0.114213 0.032118 -0.055552 -0.008358 0.024322 0.080483 0.007664 0.062874 ... -0.000377 0.046901 0.018417 -0.003170 -0.027469 0.005639 0.133500 0.032221 0.030415 -0.019664
2 safe 0.031809 -0.024030 0.026037 -0.039061 -0.002786 0.052565 0.126942 -0.014587 0.047931 ... -0.014579 0.009263 0.010371 0.045266 -0.076779 -0.023678 0.021364 0.036397 -0.001076 -0.060540
3 safe -0.061874 0.046060 0.021520 0.021778 0.009127 0.021292 0.184221 0.004618 -0.069447 ... 0.000147 0.002003 0.089495 0.040357 0.007321 0.016096 0.012340 0.062434 0.005810 0.003983
4 safe -0.053528 -0.045918 0.076154 -0.003261 0.005603 -0.058807 0.071698 0.005038 0.010522 ... 0.077442 -0.080605 -0.038363 0.013772 -0.049518 -0.004730 0.142463 -0.017870 0.013901 -0.037920

5 rows × 385 columns

In [14]:
# Persist label + embedding columns as CSV for the tangram CLI.
os.makedirs('./data', exist_ok=True)  # DataFrame.to_csv does not create missing directories
df.info()
df.to_csv('./data/input_7m_balanced_deduped_amc_MiniLM.csv', index=False)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7295434 entries, 0 to 7295433
Columns: 385 entries, label to 383
dtypes: float32(384), object(1)
memory usage: 10.5+ GB
In [ ]:
# Train a tangram model from the CSV written above (IPython `!` shell command).
# -f: input CSV, -o: model output path (loaded by the next cell), -t: column to predict ("label").
# NOTE(review): confirm whether tangram creates ./models if it does not exist.
!tangram train -f ./data/input_7m_balanced_deduped_amc_MiniLM.csv -o ./models/sensai_7m_balanced_deduped_amc_MiniLM.tangram -t label
In [4]:
# Trained classifier produced by the `tangram train` cell.
MODEL_PATH = './models/sensai_7m_balanced_deduped_amc_MiniLM.tangram'
model = tangram.Model.from_path(MODEL_PATH)

def remap_label(pred, threshold: float = 0.6) -> str:
    """Map a tangram classification to a display emoji.

    Args:
        pred: prediction object exposing ``class_name`` and ``probability``.
        threshold: 'toxic' predictions whose probability is at or below this value are
            shown as uncertain. Defaults to 0.6 (the previously hard-coded value).

    Returns:
        '👍' for any non-'toxic' class, '⚖️' for low-confidence toxic predictions,
        '🚽' for confident toxic predictions.
    """
    if pred.class_name != 'toxic':
        return '👍'
    return '⚖️' if pred.probability <= threshold else '🚽'


def predict_batch(texts):
    """Classify a batch of raw strings.

    Each text is embedded with the sentence-transformer ``st``; each embedding is handed to
    the tangram ``model`` as a dict keyed '0', '1', ... (the CSV column headers it was trained on).

    Returns:
        A list of (emoji_label, probability) tuples, one per input text, in input order.
    """
    embeddings = st.encode(texts)
    feature_rows = []
    for vector in embeddings:
        feature_rows.append({str(index): value for index, value in enumerate(vector)})
    results = []
    for prediction in model.predict(feature_rows):
        results.append((remap_label(prediction), prediction.probability))
    return results

def predict(text: str):
    """Classify a single string; returns one (emoji_label, probability) tuple."""
    batch_results = predict_batch([text])
    return batch_results[0]
In [ ]:
# Hand-picked examples spanning English/Japanese, benign chatter and abuse.
chats = [
    "lol",
    "草",
    "かわいい!!!",  # "cute!!!"
    "INAAAAAAAAAA",
    "[ENG] Heh ? don't I have way less recovery items ? don't they come back on each continue ????",
    "What's with this sassy lost child?",
    "Your village looks so cute and pretty",
    "imagine the dead team member watching him dance for no reason",
    "A型だけど、気にならないどころかむしろ安心感ある",
    "Tenchou pls even if you don't think your singing is as good, you are still one of the most electric and infectious souls I've ever seen. Ganbatte bossu!!!",
    "Nice butt",
    "My dog died let's goooooooooooo",
    "when will you graduate?",
    "lol you guys are pathetic",
    "死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね",  # "die" repeated
    "卒業卒業卒業卒業",  # "graduate" repeated
]

# Print "<emoji> : <chat>" for each example.
for text, (label, _probability) in zip(chats, predict_batch(chats)):
    print(label, ':', text)
👍 : lol
🚽 : 草
👍 : かわいい!!!
👍 : INAAAAAAAAAA
🚽 : [ENG] Heh ? don't I have way less recovery items ? don't they come back on each continue ????
👍 : What's with this sassy lost child?
👍 : Your village looks so cute and pretty
⚖️ : imagine the dead team member watching him dance for no reason
👍 : A型だけど、気にならないどころかむしろ安心感ある
🚽 : Tenchou pls even if you don't think your singing is as good, you are still one of the most electric and infectious souls I've ever seen. Ganbatte bossu!!!
🚽 : Nice butt
🚽 : My dog died let's goooooooooooo
👍 : when will you graduate?
⚖️ : lol you guys are pathetic
🚽 : 死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね
🚽 : 卒業卒業卒業卒業
In [5]:
# Evaluation needs the raw chats: by this point the training cells have merged `df`'s labels
# into 'safe'/'toxic' and replaced the text columns with embeddings, so filtering `df` on
# 'nonflagged' only worked in a kernel where those cells had not run. Reload the raw shards.
# NOTE(review): these rows come from the same parquet files used for training, so many were
# likely seen by the model and the metrics below are probably optimistic; consider a held-out
# split before training. Also consider `del df` first if memory is tight.
raw_df = pd.concat(
    [pd.read_parquet(path) for path in sorted(glob(join(DATASET_DIR, '*.parquet')))],
    ignore_index=True)
deleted_by_mods = raw_df[raw_df['label'] != 'nonflagged']
safe = raw_df[raw_df['label'] == 'nonflagged']
print('# of deleted chats', len(deleted_by_mods))
print('# of safe chats', len(safe))
# of deleted chats 6471313
# of safe chats 6471313
In [6]:
N = 50_000  # chats sampled from each class for evaluation

def adjust_label(pred):
    """Return 'safe' only for a 'safe' prediction with probability above 0.5, else 'toxic'.

    NOTE(review): not called by any cell visible in this notebook.
    """
    confidently_safe = pred.class_name == 'safe' and not pred.probability <= 0.5
    return 'safe' if confidently_safe else 'toxic'

def _format_chat_inputs(chats):
    """Build the "<author><SEP><message>" strings the model was trained on."""
    # Missing (None/NaN) or empty author names become '<EMPTY>'; NaN used to crash the concat.
    author = chats['authorName']
    author = author.where(author.notna() & author.ne(''), '<EMPTY>')
    return (author + '<SEP>' + chats['body']).to_list()


def _to_features(vectors):
    """Convert embedding rows into tangram feature dicts keyed '0', '1', ..."""
    return [{str(k): v for k, v in enumerate(vec)} for vec in vectors]


def predict_labels(chats):
    """Predict (class_name, probability) for each row of a chats DataFrame.

    Args:
        chats: DataFrame with 'authorName' and 'body' columns.

    Returns:
        A list of (class_name, probability) tuples, one per row, in row order.
        An empty frame yields an empty list.
    """
    if len(chats) == 0:
        # Nothing to encode; row-wise apply on an empty frame previously raised here.
        return []
    texts = _format_chat_inputs(chats)
    preds = model.predict(_to_features(st.encode(texts)))
    return [(p.class_name, p.probability) for p in preds]

# Fixed random_state so the evaluation sample, and therefore the reported metrics, are reproducible.
deleted_labels = predict_labels(deleted_by_mods.sample(N, random_state=42))
safe_labels = predict_labels(safe.sample(N, random_state=42))
In [7]:
def f_beta(prec, rec, beta):
    """F-beta score: weighted harmonic mean of precision and recall.

    beta > 1 weights recall more heavily than precision.

    Returns:
        The F-beta score, or 0.0 when the denominator is zero (e.g. precision and recall
        both 0) instead of raising ZeroDivisionError.
    """
    denom = beta**2 * prec + rec
    if denom == 0:
        return 0.0
    return (1 + beta**2) * ((prec * rec) / denom)

# Positive class = chats deleted/hidden by moderators; negative class = nonflagged chats.
# Class sizes come from the prediction lists themselves rather than assuming both equal N.
n_pos = len(deleted_labels)
n_neg = len(safe_labels)
fn = sum(1 for label, _ in deleted_labels if label == 'safe')
fp = sum(1 for label, _ in safe_labels if label != 'safe')
tp = n_pos - fn
tn = n_neg - fp
pp = tp + fp
fnr = fn / n_pos
fpr = fp / n_neg
sensitivity = tp / n_pos
precision = tp / pp if pp else 0.0  # no positive predictions -> define precision as 0
accuracy = (tp + tn) / (n_pos + n_neg)
fb = f_beta(precision, sensitivity, 2)
print('N', n_pos + n_neg)
print('FPR', fpr)
print('FNR', fnr)
print('Sensitivity', sensitivity)
print('Precision', precision)
print('Accuracy', accuracy)
print('F(2) score', fb)
N 100000
FPR 0.14654
FNR 0.181
Sensitivity 0.819
Precision 0.8482300060070013
Accuracy 0.83623
F(2) score 0.8246837201996158
In [23]:
# Spot-check a single chat.
probe = pd.DataFrame([{'authorName': 'Skamor', 'body': '味ってわかる?"あじ"って読むんやで'}])
predict_labels(probe)
Out[23]:
[('safe', 0.5598429441452026)]
In [26]:
# Scratch: shape of a one-row chats frame as consumed by predict_labels.
pd.DataFrame([{'authorName': 'Hi', 'body': 'Hi'}])
Out[26]:
authorName body
0 Hi Hi
In [111]:
import json

def convertRawMessageToString(rawMessage):
    """Flatten a list of message runs into plain text.

    Each run is a single-key dict: ``{'text': str}`` or ``{'emoji': {...}}``. Emoji payloads
    that carry an 'isCustomEmoji' key are replaced by U+FFFD; other emoji contribute their
    'emojiId'.

    Raises:
        ValueError: if a run is empty or has an unsupported type.
    """
    def handler(run):
        if not run:
            raise ValueError('Empty message run')
        msgType = next(iter(run))
        payload = run[msgType]
        if msgType == 'text':
            return payload
        if msgType == 'emoji':
            # Custom emoji presumably have no Unicode character; use the replacement char.
            return "\uFFFD" if 'isCustomEmoji' in payload else payload['emojiId']
        # Raise a real exception: a bare str cannot be raised in Python 3, and the payload
        # is usually a dict, which cannot be concatenated to a str.
        raise ValueError(f'Invalid type: {msgType}, {payload!r}')
    return "".join(handler(run) for run in rawMessage)

# Score a captured live-chat log. Each JSONL record needs at least 'authorName' and 'message'
# (a list of text/emoji runs, flattened by convertRawMessageToString).
HOLD_THRESHOLD = 0.85  # predictions below this confidence are labelled 'hold'

with open('addChatItemAction.jsonl', encoding='utf-8') as f:
    # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
    lines = [json.loads(line) for line in f if line.strip()]
chats = pd.DataFrame.from_dict(lines)
# .copy() so the column assignments below modify an independent frame, not a slice.
chats = chats[['authorName', 'message']].copy()
chats['message'] = chats['message'].apply(convertRawMessageToString)
chats = chats.rename(columns={'message': 'body'})
# Assign a plain list (one entry per row, in order) rather than an index-aligned Series.
chats['label'] = [
    'hold' if prob < HOLD_THRESHOLD else {'safe': 'safe', 'toxic': 'annoying'}[name]
    for name, prob in predict_labels(chats)
]
In [106]:
# Eyeball a random selection of non-trivial chats the model considers safe.
is_safe = chats['label'] == 'safe'
is_long = chats['body'].str.len() > 3
chats.loc[is_safe & is_long].sample(20)
Out[106]:
authorName body label
3690 Miguel Islas lmao safe
451 mamin ㅋㅋㅋㅋㅋㅋ safe
7977 Pierah 🧡💛💙💜 safe
1993 Kuga Raian LMAOOOO😳😳😳 safe
5752 BOY tenten こんにちはー safe
9304 linkis20 💙💙💙💙💙 safe
10540 MegaSlayer92 💙💙💙💙💙💙 safe
5562 Brandermau You're welcome safe
8890 零羅 かわいい safe
8155 Kronii's Apostle [8th of the Twelve] ayy pog moment safe
7414 linkis20 💙🧡💜💛 safe
7976 LongtimeNoc-KFP-Nephamily Is happy to make one of your dreams finally co... safe
10969 Bishop Johnson 💙💙💙💙 safe
7504 Arthur 🧡💛💙💜 safe
9343 Nova Star You did great safe
544 BreakingPhobia103 FEET� safe
11250 Oliver Chen LOOOOOOOOOL safe
2791 Kronii's Apostle [8th of the Twelve] is mumei really that beeg? safe
9991 linkis20 💙💙💙💙💙 safe
7843 »pingpong 💙💛💜🧡 safe