{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DATASET_DIR /home/uetchy/repos/src/github.com/holodata/sensai-huggingface\n" ] } ], "source": [ "import os\n", "from os.path import join\n", "from glob import glob\n", "import pandas as pd\n", "from sentence_transformers import SentenceTransformer\n", "import tangram\n", "\n", "DATASET_DIR = os.environ.get(\"DATASET_DIR\", \"../input/sensai\")\n", "print(\"DATASET_DIR\", DATASET_DIR)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "\n", "# load tokenizer\n", "st = SentenceTransformer('all-MiniLM-L6-v2')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data...\n" ] } ], "source": [ "print('Loading data...')\n", "df = pd.concat(\n", " [pd.read_parquet(x) for x in glob(join(DATASET_DIR, '*.parquet'))],\n", " ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sampling data...\n", "\n", "Int64Index: 9121058 entries, 0 to 12942624\n", "Data columns (total 3 columns):\n", " # Column Dtype \n", "--- ------ ----- \n", " 0 authorName object\n", " 1 body object\n", " 2 label object\n", "dtypes: object(3)\n", "memory usage: 278.4+ MB\n", "sample_size 3647717\n", "\n", "RangeIndex: 7295434 entries, 0 to 7295433\n", "Data columns (total 3 columns):\n", " # Column Dtype \n", "--- ------ ----- \n", " 0 authorName object\n", " 1 body object\n", " 2 label object\n", "dtypes: object(3)\n", "memory usage: 167.0+ MB\n", "label\n", "safe 3647717\n", "toxic 3647717\n", "dtype: int64\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
authorNamebodylabel
2646589それもえナイスーsafe
6722778LassieOWK U S A K E N ? ? ? K U S A K E N ? ? ? K U S...toxic
4490817ヤルダバオトここに猫左下toxic
522319IkeShamus5safe
6359809susi Susantikeigo tsunderetoxic
\n", "
" ], "text/plain": [ " authorName body \\\n", "2646589 それもえ ナイスー \n", "6722778 LassieOW K U S A K E N ? ? ? K U S A K E N ? ? ? K U S... \n", "4490817 ヤルダバオトここに 猫左下 \n", "522319 IkeShamus 5 \n", "6359809 susi Susanti keigo tsundere \n", "\n", " label \n", "2646589 safe \n", "6722778 toxic \n", "4490817 toxic \n", "522319 safe \n", "6359809 toxic " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Sampling data...')\n", "\n", "# Chats with body size larger than 1\n", "# df = df[df['body'].str.len() > 1]\n", "\n", "# Merge labels\n", "df['label'] = df['label'].apply(lambda x: {\n", " 'deleted': 'toxic',\n", " 'hidden': 'toxic',\n", " 'nonflagged': 'safe'\n", "}[x])\n", "\n", "# Drop duplicates\n", "df = df.drop_duplicates()\n", "df.info()\n", "\n", "# Balance item count for each category\n", "g = df.groupby('label')\n", "sample_size = min(g.size().min(), 4_000_000)\n", "print('sample_size', sample_size)\n", "df = g.apply(lambda x: x.sample(sample_size)).reset_index(drop=True)\n", "\n", "df.info()\n", "print(df.groupby('label').size())\n", "df.sample(5)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Formatting data...\n", "\n", "RangeIndex: 7295434 entries, 0 to 7295433\n", "Data columns (total 2 columns):\n", " # Column Dtype \n", "--- ------ ----- \n", " 0 body object\n", " 1 label object\n", "dtypes: object(2)\n", "memory usage: 111.3+ MB\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
bodylabel
6998597allormist<SEP>demon slayer copied this ideatoxic
2536692Rei<SEP>ああああああああsafe
3713136Alex<SEP>Look at the first part of the Quest. ...toxic
5614434Nov4Zyz<SEP>かわいいtoxic
517705Aida<SEP>lmaoooosafe
\n", "
" ], "text/plain": [ " body label\n", "6998597 allormistdemon slayer copied this idea toxic\n", "2536692 Reiああああああああ safe\n", "3713136 AlexLook at the first part of the Quest. ... toxic\n", "5614434 Nov4Zyzかわいい toxic\n", "517705 Aidalmaoooo safe" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Formatting data...')\n", "\n", "# Add counting features\n", "# df['blen'] = df['body'].str.len()\n", "\n", "# Concat author name and message with a separator\n", "df['body'] = df.apply(lambda x: (x['authorName'] or '') + '' + x['body'], axis=1)\n", "df.drop(columns=['authorName'], inplace=True)\n", "\n", "df.info()\n", "df.sample(5)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating embedding vectors...\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
label012345678...374375376377378379380381382383
0safe0.0096120.0524240.083845-0.0315950.0288930.0590270.045335-0.0143430.037504...-0.0882240.0372810.0204760.0674150.000324-0.0495880.0063970.1391690.049218-0.025690
1safe0.0033280.1142130.032118-0.055552-0.0083580.0243220.0804830.0076640.062874...-0.0003770.0469010.018417-0.003170-0.0274690.0056390.1335000.0322210.030415-0.019664
2safe0.031809-0.0240300.026037-0.039061-0.0027860.0525650.126942-0.0145870.047931...-0.0145790.0092630.0103710.045266-0.076779-0.0236780.0213640.036397-0.001076-0.060540
3safe-0.0618740.0460600.0215200.0217780.0091270.0212920.1842210.004618-0.069447...0.0001470.0020030.0894950.0403570.0073210.0160960.0123400.0624340.0058100.003983
4safe-0.053528-0.0459180.076154-0.0032610.005603-0.0588070.0716980.0050380.010522...0.077442-0.080605-0.0383630.013772-0.049518-0.0047300.142463-0.0178700.013901-0.037920
\n", "

5 rows × 385 columns

\n", "
" ], "text/plain": [ " label 0 1 2 3 4 5 6 \\\n", "0 safe 0.009612 0.052424 0.083845 -0.031595 0.028893 0.059027 0.045335 \n", "1 safe 0.003328 0.114213 0.032118 -0.055552 -0.008358 0.024322 0.080483 \n", "2 safe 0.031809 -0.024030 0.026037 -0.039061 -0.002786 0.052565 0.126942 \n", "3 safe -0.061874 0.046060 0.021520 0.021778 0.009127 0.021292 0.184221 \n", "4 safe -0.053528 -0.045918 0.076154 -0.003261 0.005603 -0.058807 0.071698 \n", "\n", " 7 8 ... 374 375 376 377 378 \\\n", "0 -0.014343 0.037504 ... -0.088224 0.037281 0.020476 0.067415 0.000324 \n", "1 0.007664 0.062874 ... -0.000377 0.046901 0.018417 -0.003170 -0.027469 \n", "2 -0.014587 0.047931 ... -0.014579 0.009263 0.010371 0.045266 -0.076779 \n", "3 0.004618 -0.069447 ... 0.000147 0.002003 0.089495 0.040357 0.007321 \n", "4 0.005038 0.010522 ... 0.077442 -0.080605 -0.038363 0.013772 -0.049518 \n", "\n", " 379 380 381 382 383 \n", "0 -0.049588 0.006397 0.139169 0.049218 -0.025690 \n", "1 0.005639 0.133500 0.032221 0.030415 -0.019664 \n", "2 -0.023678 0.021364 0.036397 -0.001076 -0.060540 \n", "3 0.016096 0.012340 0.062434 0.005810 0.003983 \n", "4 -0.004730 0.142463 -0.017870 0.013901 -0.037920 \n", "\n", "[5 rows x 385 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Calculating embedding vectors...')\n", "\n", "emb = pd.DataFrame(st.encode(df['body'].to_list(), convert_to_numpy=True))\n", "df = pd.concat([df, emb], axis=1).drop(columns=['body'])\n", "\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 7295434 entries, 0 to 7295433\n", "Columns: 385 entries, label to 383\n", "dtypes: float32(384), object(1)\n", "memory usage: 10.5+ GB\n" ] } ], "source": [ "df.info()\n", "df.to_csv('./data/input_7m_balanced_deduped_amc_MiniLM.csv', index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!tangram train -f ./data/input_7m_balanced_deduped_amc_MiniLM.csv -o ./models/sensai_7m_balanced_deduped_amc_MiniLM.tangram -t label" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model = tangram.Model.from_path('./models/sensai_7m_balanced_deduped_amc_MiniLM.tangram')\n", "\n", "def remap_label(pred) -> str:\n", " label = '👍'\n", " if pred.class_name == 'toxic':\n", " if pred.probability <= 0.6:\n", " label = '⚖️'\n", " else:\n", " label = '🚽'\n", " return label\n", "\n", "\n", "def predict_batch(texts):\n", " preds = model.predict([{str(k): v for k,v in enumerate(x)} for x in st.encode(texts)])\n", " return [(remap_label(pred), pred.probability) for pred in preds]\n", "\n", "def predict(text: str):\n", " return predict_batch([text])[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "👍 : lol\n", "🚽 : 草\n", "👍 : かわいい!!!\n", "👍 : INAAAAAAAAAA\n", "🚽 : [ENG] Heh ? don't I have way less recovery items ? don't they come back on each continue ????\n", "👍 : What's with this sassy lost child?\n", "👍 : Your village looks so cute and pretty\n", "⚖️ : imagine the dead team member watching him dance for no reason\n", "👍 : A型だけど、気にならないどころかむしろ安心感ある\n", "🚽 : Tenchou pls even if you don't think your singing is as good, you are still one of the most electric and infectious souls I've ever seen. 
Ganbatte bossu!!!\n", "🚽 : Nice butt\n", "🚽 : My dog died let's goooooooooooo\n", "👍 : when will you graduate?\n", "⚖️ : lol you guys are pathetic\n", "🚽 : 死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね\n", "🚽 : 卒業卒業卒業卒業\n" ] } ], "source": [ "chats = [\n", " \"lol\",\n", " \"草\", # lol (lit. \"grass\")\n", " \"かわいい!!!\", # cute !!!\n", " \"INAAAAAAAAAA\",\n", " \"[ENG] Heh ? don't I have way less recovery items ? don't they come back on each continue ????\",\n", " \"What's with this sassy lost child?\",\n", " \"Your village looks so cute and pretty\",\n", " \"imagine the dead team member watching him dance for no reason\",\n", " \"A型だけど、気にならないどころかむしろ安心感ある\", # I'm blood type A, but it's reassuring rather than bothersome\n", " \"Tenchou pls even if you don't think your singing is as good, you are still one of the most electric and infectious souls I've ever seen. Ganbatte bossu!!!\",\n", " \"Nice butt\",\n", " \"My dog died let's goooooooooooo\",\n", " \"when will you graduate?\",\n", " \"lol you guys are pathetic\",\n", " \"死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね死ね\", # die die die die die die\n", " \"卒業卒業卒業卒業\", # graduate graduate graduate graduate\n", "]\n", "\n", "for text, pred in zip(chats, predict_batch(chats)):\n", " print(pred[0], ':', text)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# of deleted chats 6471313\n", "# of safe chats 6471313\n" ] } ], "source": [ "# Note: df here still carries the raw labels ('deleted'/'hidden'/'nonflagged') from a fresh load\n", "deleted_by_mods = df[df['label'] != 'nonflagged']\n", "safe = df[df['label'] == 'nonflagged']\n", "print('# of deleted chats', len(deleted_by_mods))\n", "print('# of safe chats', len(safe))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "N = 50000\n", "\n", "def adjust_label(pred):\n", " if pred.class_name == 'safe':\n", " if pred.probability <= 0.5:\n", " return 'toxic'\n", " else:\n", " return 'safe'\n", " return 'toxic'\n", "\n", "def predict_labels(chats):\n", " bodies = chats.apply(lambda x: (x['authorName'] or '') + '<SEP>' + x['body'], axis=1).to_list()\n", " preds = model.predict([{str(k): v for k, v in enumerate(x)} for x in st.encode(bodies)])\n", " return [(x.class_name, x.probability) for x in preds]\n", "\n", "deleted_labels = predict_labels(deleted_by_mods.sample(N))\n", "safe_labels = predict_labels(safe.sample(N))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "N 100000\n", "FPR 0.14654\n", "FNR 0.181\n", "Sensitivity 0.819\n", "Precision 0.8482300060070013\n", "Accuracy 0.83623\n", "F(2) score 0.8246837201996158\n" ] } ], "source": [ "def f_beta(prec, rec, beta):\n", " return (1+beta**2) * ( (prec * rec) / ( beta**2 * prec + rec ) )\n", "\n", "fn = len([x for x in deleted_labels if x[0] == 'safe'])\n", "fp = len([x for x in safe_labels if x[0] != 'safe'])\n", "tp = N - fn\n", "tn = N - fp\n", "pp = tp + fp\n", "fnr = fn / N\n", "fpr = fp / N\n", "sensitivity = tp / N\n", "precision = tp / pp\n", "accuracy = (tp + tn) / (N*2)\n", "fb = f_beta(precision, sensitivity, 2)\n", "print('N', len(deleted_labels)+len(safe_labels))\n", "print('FPR', fpr)\n", "print('FNR', fnr)\n", "print('Sensitivity', sensitivity)\n", "print('Precision', precision)\n", "print('Accuracy', accuracy)\n", "print('F(2) score', fb)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('safe', 0.5598429441452026)]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict_labels(pd.DataFrame.from_dict([{'authorName': 'Skamor', 'body': '味ってわかる?\"あじ\"って読むんやで'}]))" ] }, {
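"cell_type": "markdown", "metadata": {}, "source": [ "Sanity-checking the metrics above: with $N = 50000$ per class, the printed rates imply $fn = 0.181 \\cdot 50000 = 9050$ and $fp = 0.14654 \\cdot 50000 = 7327$, hence $tp = 40950$ and\n", "\n", "$$\\mathrm{precision} = \\frac{tp}{tp + fp} = \\frac{40950}{48277} \\approx 0.8482, \\qquad F_\\beta = (1+\\beta^2)\\,\\frac{\\mathrm{precision} \\cdot \\mathrm{recall}}{\\beta^2\\,\\mathrm{precision} + \\mathrm{recall}}.$$\n", "\n", "With $\\beta = 2$, recall (sensitivity) is weighted four times as heavily as precision, which suits moderation, where a missed toxic chat costs more than an extra flag: $F_2 = 5 \\cdot \\frac{0.8482 \\cdot 0.819}{4 \\cdot 0.8482 + 0.819} \\approx 0.8247$, matching the printed value." ] }, {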
"cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
authorNamebody
0HiHi
\n", "
" ], "text/plain": [ " authorName body\n", "0 Hi Hi" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame.from_dict([{'authorName': 'Hi', 'body': 'Hi'}])" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "def convertRawMessageToString(rawMessage):\n", " def handler(run):\n", " msgType = list(run.keys())[0]\n", " payload = run[msgType]\n", " if msgType == 'text':\n", " return payload\n", " elif msgType == 'emoji':\n", " if 'isCustomEmoji' in payload:\n", " return \"\\uFFFD\"\n", " else:\n", " return payload['emojiId']\n", " else:\n", " raise 'Invalid type: ' + msgType + ', ' + payload\n", " return \"\".join([handler(run) for run in rawMessage])\n", "\n", "with open('addChatItemAction.jsonl') as f:\n", " lines = list(map(json.loads, f.readlines()))\n", "chats = pd.DataFrame.from_dict(lines)\n", "chats = chats[['authorName', 'message']]\n", "chats['message'] = chats['message'].apply(convertRawMessageToString)\n", "chats = chats.rename(columns={'message': 'body'})\n", "chats['label'] = pd.Series(map(lambda x: 'hold' if x[1] < 0.85 else {'safe': 'safe', 'toxic': 'annoying'}[x[0]], predict_labels(chats)))" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
authorNamebodylabel
3690Miguel Islaslmaosafe
451maminㅋㅋㅋㅋㅋㅋsafe
7977Pierah🧡💛💙💜❤safe
1993Kuga RaianLMAOOOO😳😳😳safe
5752BOY tentenこんにちはーsafe
9304linkis20💙💙💙💙💙safe
10540MegaSlayer92💙💙💙💙💙💙safe
5562BrandermauYou're welcomesafe
8890零羅かわいいsafe
8155Kronii's Apostle [8th of the Twelve]ayy pog momentsafe
7414linkis20💙❤🧡💜💛safe
7976LongtimeNoc-KFP-NephamilyIs happy to make one of your dreams finally co...safe
10969Bishop Johnson💙💙💙💙safe
7504Arthur❤🧡💛💙💜safe
9343Nova StarYou did greatsafe
544BreakingPhobia103FEET�safe
11250Oliver ChenLOOOOOOOOOLsafe
2791Kronii's Apostle [8th of the Twelve]is mumei really that beeg?safe
9991linkis20💙💙💙💙💙safe
7843»pingpong💙💛💜❤🧡safe
\n", "
" ], "text/plain": [ " authorName \\\n", "3690 Miguel Islas \n", "451 mamin \n", "7977 Pierah \n", "1993 Kuga Raian \n", "5752 BOY tenten \n", "9304 linkis20 \n", "10540 MegaSlayer92 \n", "5562 Brandermau \n", "8890 零羅 \n", "8155 Kronii's Apostle [8th of the Twelve] \n", "7414 linkis20 \n", "7976 LongtimeNoc-KFP-Nephamily \n", "10969 Bishop Johnson \n", "7504 Arthur \n", "9343 Nova Star \n", "544 BreakingPhobia103 \n", "11250 Oliver Chen \n", "2791 Kronii's Apostle [8th of the Twelve] \n", "9991 linkis20 \n", "7843 »pingpong \n", "\n", " body label \n", "3690 lmao safe \n", "451 ㅋㅋㅋㅋㅋㅋ safe \n", "7977 🧡💛💙💜❤ safe \n", "1993 LMAOOOO😳😳😳 safe \n", "5752 こんにちはー safe \n", "9304 💙💙💙💙💙 safe \n", "10540 💙💙💙💙💙💙 safe \n", "5562 You're welcome safe \n", "8890 かわいい safe \n", "8155 ayy pog moment safe \n", "7414 💙❤🧡💜💛 safe \n", "7976 Is happy to make one of your dreams finally co... safe \n", "10969 💙💙💙💙 safe \n", "7504 ❤🧡💛💙💜 safe \n", "9343 You did great safe \n", "544 FEET� safe \n", "11250 LOOOOOOOOOL safe \n", "2791 is mumei really that beeg? safe \n", "9991 💙💙💙💙💙 safe \n", "7843 💙💛💜❤🧡 safe " ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chats[(chats['label'] == 'safe') & (chats['body'].str.len() > 3)].sample(20)" ] } ], "metadata": { "interpreter": { "hash": "24403c88b9bd347a0b41a9fc3e3175e2948f6e2f45c79c9df260f2a479a817d3" }, "kernelspec": { "display_name": "Python 3.8.6 64-bit ('.venv': poetry)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }