Mirror of https://github.com/holodata/sensai-dataset.git (synced 2025-11-04 09:08:57 +09:00)

Commit: refactor: add more examples

.gitattributes (vendored, 2 lines changed)

@@ -1 +1 @@
-notebooks/*.ipynb linguist-vendored
+notebooks/*.ipynb linguist-vendored
							
								
								
									
.gitignore (vendored, 3 lines changed)

@@ -1,6 +1,9 @@
 .envrc
 .env
 .vscode
+TODO
+*.csv
+*.tangram
 /tmp
 
 # Created by https://www.toptal.com/developers/gitignore/api/python
							
								
								
									
Makefile (4 lines changed)

@@ -1,7 +1,7 @@
-all: build upload
+all: build
 
 build:
-	python3 -m sensai_dataset.gen
+	python3 -m sensai_dataset.generator
 
 upload:
 	kaggle datasets version -m "New version" --path $$DATASET_DIR
							
								
								
									
README.md (20 lines changed)

@@ -32,7 +32,7 @@ Ban and deletion are equivalent to `markChatItemsByAuthorAsDeletedAction` and `m
 
 ## Dataset Breakdown
 
-### Chats (`chats_%Y-%m.csv`)
+### Chats (`chats_%Y-%m.parquet`)
 
 | column          | type   | description                   |
 | --------------- | ------ | ----------------------------- |
@@ -43,6 +43,17 @@ Ban and deletion are equivalent to `markChatItemsByAuthorAsDeletedAction` and `m
 
 ## Usage
 
+### Pandas
+
+```python
+import pandas as pd
+from glob import glob
+
+df = pd.concat([pd.read_parquet(x) for x in glob('../input/sensai/*.parquet')], ignore_index=True)
+```
+
+### Huggingface Transformers
+
 https://huggingface.co/docs/datasets/loading_datasets.html
 
 ```python
@@ -75,6 +86,13 @@ trainer = Trainer(
 trainer.train()
 ```
 
+### Tangram
+
+```bash
+python3 ./examples/prepare_tangram_dataset.py
+tangram train --file ./tangram_input.csv --target label
+```
+
 ## Consideration
 
 ### Anonymization
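Beyond the snippets added to the README above, a quick sanity check after loading is to look at the label balance. This is a minimal sketch, not part of the commit, assuming the same `../input/sensai/*.parquet` layout and the `body`/`label` columns shown in the examples:

```python
import pandas as pd
from glob import glob

# Load every parquet shard and count rows per moderation label
# (labels produced by the generator: 'hidden', 'deleted', 'nonflagged').
df = pd.concat([pd.read_parquet(x) for x in glob('../input/sensai/*.parquet')],
               ignore_index=True)
print(df['label'].value_counts())
print(df['body'].str.len().describe())  # rough message-length profile
```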
							
								
								
									
examples/addChatItemAction.jsonl (new file, 11341 lines)
File diff suppressed because one or more lines are too long.
											
										
									
								
							
							
								
								
									
examples/pandas.ipynb (new file, 164 lines)

New notebook that loads the dataset with pandas. Cell sources and key outputs:

Cell 1 (In [2]):
    import pandas as pd
    from os.path import join
    import os
    from glob import glob

    DATASET_DIR = os.environ.get("DATASET_DIR", "../input/sensai")
    print("DATASET_DIR", DATASET_DIR)

    # stdout: DATASET_DIR /home/uetchy/repos/src/github.com/holodata/sensai-huggingface

Cell 2 (In [3]):
    df = pd.concat(
        [pd.read_parquet(x) for x in glob(join(DATASET_DIR, '*.parquet'))],
        ignore_index=True)
    df.info()

    # stdout:
    # <class 'pandas.core.frame.DataFrame'>
    # RangeIndex: 10677038 entries, 0 to 10677037
    # Data columns (total 2 columns):
    #  #   Column  Dtype
    #  0   body    object
    #  1   label   object
    # dtypes: object(2)
    # memory usage: 162.9+ MB

Cell 3 (In [11]):
    df.sample(5)

    # Out[11]:
    #                     body       label
    # 6229407   Blessed stream      hidden
    # 7406071              RIP  nonflagged
    # 920434              cute  nonflagged
    # 6146625  GACHA lets gooo      hidden
    # 8259711                草      hidden

(Empty trailing code cell; notebook metadata: Python 3.8.6 kernel, '.venv' poetry environment, nbformat 4.)
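As a follow-up to this notebook, the flagged and non-flagged shards can also be loaded separately. This is a hedged sketch, not part of the commit, assuming the published files keep the `chats_flagged_*.parquet` / `chats_nonflag_*.parquet` naming used by the generator changes below and reusing the notebook's `DATASET_DIR`:

```python
import os
from glob import glob
from os.path import join

import pandas as pd

DATASET_DIR = os.environ.get("DATASET_DIR", "../input/sensai")

# The generator emits one flagged and one non-flagged file per month.
flagged = pd.concat([pd.read_parquet(p) for p in
                     sorted(glob(join(DATASET_DIR, 'chats_flagged_*.parquet')))],
                    ignore_index=True)
nonflag = pd.concat([pd.read_parquet(p) for p in
                     sorted(glob(join(DATASET_DIR, 'chats_nonflag_*.parquet')))],
                    ignore_index=True)

print(len(flagged), len(nonflag))  # should be roughly balanced by construction
```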
							
								
								
									
examples/tangram.ipynb (new file, 1069 lines)
File diff suppressed because it is too large.
											
										
									
								
							
							
								
								
									
examples/toxic-chat-classification-using-lstm.ipynb (new file, 1 line)
File diff suppressed because one or more lines are too long.
											
										
									
								
							
							
								
								
									
examples/transformers.ipynb (new file, 132 lines)

New notebook that fine-tunes bert-base-cased on the dataset with Hugging Face Transformers (no cell outputs are stored). Cell sources:

Cell 1:
    from datasets import load_dataset, Features
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
    import os
    from datasets import ClassLabel, Value

    # https://huggingface.co/docs/datasets/loading_datasets.html

    DATASET_DIR = os.environ.get("DATASET_DIR", "../input/sensai")
    print("DATASET_DIR", DATASET_DIR)

Cell 2:
    dataset = load_dataset("holodata/sensai", features=Features(
                    {
                        "body": Value("string"),
                        "toxic": ClassLabel(num_classes=2, names=['0', '1'])
                    }
                ))
    dataset = dataset['train']

Cell 3:
    dataset.features

Cell 4:
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

Cell 5:
    samples = dataset.shuffle().select(range(50000))

    def tokenize_function(examples):
        return tokenizer(examples["body"], padding="max_length", truncation=True)

    tokenized_datasets = samples.map(tokenize_function, batched=True)
    tokenized_datasets.rename_column_("toxic", "label")

Cell 6:
    splitset = tokenized_datasets.train_test_split(0.2)
    splitset

Cell 7:
    training_args = TrainingArguments("test_trainer")
    trainer = Trainer(
        model=model, args=training_args, train_dataset=splitset['train'], eval_dataset=splitset['test']
    )

Cell 8:
    trainer.train(resume_from_checkpoint=True)

(Empty trailing code cell; notebook metadata: Python 3.8.6 64-bit kernel, nbformat 4.)
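The notebook stops at `trainer.train(...)`. A minimal sketch of how one might evaluate the fine-tuned checkpoint on the held-out split follows; this is not part of the commit and assumes the `trainer` and `splitset` objects defined above:

```python
import numpy as np

def compute_metrics(eval_pred):
    # eval_pred is the (logits, labels) pair supplied by Trainer during evaluation.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

# Attach the metric function and run evaluation on the 20% test split.
trainer.compute_metrics = compute_metrics
metrics = trainer.evaluate(eval_dataset=splitset['test'])
print(metrics)
```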
							
								
								
									
notebooks/playground.ipynb (vendored, 354 lines deleted)

Removed notebook; its content closely matches the new examples/transformers.ipynb. Cells and key outputs of the deleted file:

Cell 1 (In [1]):
    from datasets import load_dataset, Features
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
    import os
    from os.path import join
    import pandas as pd
    from datasets import ClassLabel, Value

    # https://huggingface.co/docs/datasets/loading_datasets.html

    DATASET_DIR = os.environ['DATASET_DIR']

(Empty markdown cell.)

Cell 2 (In [2]):
    dataset = load_dataset("holodata/sensai", features=Features(
                    {
                        "body": Value("string"),
                        "toxic": ClassLabel(num_classes=2, names=['0', '1'])
                    }
                ))
    dataset = dataset['train']

    # stderr: Using custom data configuration sensai-4d9ed81389161083
    # stderr: Reusing dataset parquet (/home/uetchy/.cache/huggingface/datasets/parquet/sensai-4d9ed81389161083/0.0.0/9296ce43568b20d72ff8ff8ecbc821a16b68e9b8b7058805ef11f06e035f911a)

Cell 3 (In [3]):
    dataset.features

    # Out[3]: {'body': Value(dtype='string', id=None),
    #          'toxic': ClassLabel(num_classes=2, names=['0', '1'], names_file=None, id=None)}

Cell 4 (In [4]):
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # stderr: warnings that some bert-base-cased weights were not used and that the
    # classifier head is newly initialized ("You should probably TRAIN this model on a
    # down-stream task to be able to use it for predictions and inference.")

Cell 5 (In [5]):
    samples = dataset.shuffle().select(range(50000))

    def tokenize_function(examples):
        return tokenizer(examples["body"], padding="max_length", truncation=True)

    tokenized_datasets = samples.map(tokenize_function, batched=True)
    tokenized_datasets.rename_column_("toxic", "label")

    # stderr: loads cached shuffled/processed Arrow files; FutureWarning that
    # rename_column_ is deprecated in favor of Dataset.rename_column

Cell 6 (In [6]):
    splitset = tokenized_datasets.train_test_split(0.2)
    splitset

    # Out[6]: DatasetDict with features ['attention_mask', 'body', 'input_ids',
    #         'token_type_ids', 'label']; train: 40000 rows, test: 10000 rows

Cell 7 (In [7]):
    training_args = TrainingArguments("test_trainer")
    trainer = Trainer(
        model=model, args=training_args, train_dataset=splitset['train'], eval_dataset=splitset['test']
    )

Cell 8 (In [9]):
    trainer.train(resume_from_checkpoint=True)

    # stderr: resumes from test_trainer/checkpoint-12500; 40000 examples, 3 epochs,
    #         batch size 8, 15000 optimization steps; the 'body' column is ignored by
    #         BertForSequenceClassification.forward
    # progress: [15000/15000 31:45, Epoch 3/3]
    # training loss: step 13000: 0.6875, 13500: 0.6863, 14000: 0.6379,
    #                14500: 0.6432, 15000: 0.6277
    # checkpoints saved to test_trainer/checkpoint-13000 through checkpoint-15000
    # Out[9]: TrainOutput(global_step=15000, training_loss=0.10941998901367188,
    #         metrics={'train_runtime': 1918.0916, 'train_samples_per_second': 62.562,
    #                  'train_steps_per_second': 7.82, 'total_flos': 3.24994775580672e+16,
    #                  'train_loss': 0.10941998901367188, 'epoch': 3.0})

(Empty trailing code cell; notebook metadata: Python 3.8.6 64-bit kernel, nbformat 4.)
							
								
								
									
poetry.lock (generated, 2941 lines changed)
File diff suppressed because it is too large.
											
										
									
								
pyproject.toml

@@ -8,18 +8,20 @@ authors = ["Yasuaki Uechi <y@uechi.io>"]
 python = "^3.8"
 datasets = {git = "https://github.com/huggingface/datasets.git", rev = "master"}
 #datasets = "^1.11.0"
-pyarrow = "^5.0.0"
+pyarrow = "^7.0.0"
 
 [tool.poetry.dev-dependencies]
-pymongo = "^3.11.3"
-transformers = "^4.10.0"
-torch = "^1.9.0"
-tokenizers = "^0.10.3"
+pymongo = "^4.0.2"
+transformers = "^4.17.0"
+torch = "^1.11.0"
+tokenizers = "^0.11.6"
 kaggle = "^1.5.12"
-pandas = "^1.2.3"
-python-dateutil = "^2.8.1"
-ipykernel = "^6.2.0"
-yapf = "^0.31.0"
+pandas = "^1.4.1"
+python-dateutil = "^2.8.2"
+yapf = "^0.32.0"
+sentence-transformers = "^2.2.0"
+jupyter = "^1.0.0"
+tangram = "^0.7.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
sensai_dataset/generator/__main__.py

@@ -1,16 +1,21 @@
 import argparse
 
-from sensai_dataset.generator.commands import generate_dataset
-from sensai_dataset.generator.constants import DATASET_DIR, DATASET_SOURCE_DIR
+from sensai_dataset.generator.commands import generate_dataset, generate_reduced_dataset
+from sensai_dataset.generator.constants import SENSAI_COMPLETE_DIR, SENSAI_DIR, DATASET_SOURCE_DIR
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='dataset generator')
-    parser.add_argument('-m', '--matcher', type=str, default='chats_*.csv')
+    parser.add_argument('-m', '--matcher', type=str, default='chats_*.parquet')
     args = parser.parse_args()
 
-    print('target: ' + DATASET_DIR)
     print('source: ' + DATASET_SOURCE_DIR)
+    print('SENSAI_COMPLETE_DIR: ' + SENSAI_COMPLETE_DIR)
+    print('SENSAI_DIR: ' + SENSAI_DIR)
 
     generate_dataset(source_dir=DATASET_SOURCE_DIR,
-                     target_dir=DATASET_DIR,
+                     target_dir=SENSAI_COMPLETE_DIR,
                      matcher=args.matcher)
+
+    generate_reduced_dataset(source_dir=DATASET_SOURCE_DIR,
+                             target_dir=SENSAI_DIR,
+                             matcher=args.matcher)
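For reference, a sketch of driving the updated entry point from a script rather than via `make build`. This is not part of the commit; the directory paths and the matcher value are placeholders, and the environment variable names come from the new constants module:

```python
import os
import runpy
import sys

# Placeholder locations; the generator reads these names from the environment.
os.environ.setdefault("DATASET_SOURCE_DIR", "/data/sensai-source")
os.environ.setdefault("SENSAI_COMPLETE_DIR", "/data/sensai-complete")
os.environ.setdefault("SENSAI_DIR", "/data/sensai")

# Restrict the run to a single month (the default matcher is 'chats_*.parquet').
sys.argv = ["sensai_dataset.generator", "-m", "chats_2021-08.parquet"]
runpy.run_module("sensai_dataset.generator", run_name="__main__")
```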
sensai_dataset/generator/commands.py

@@ -6,16 +6,17 @@ import pandas as pd
 
 
 def generate_dataset(source_dir, target_dir, matcher):
-    print('[generate_sensai_dataset]')
+    print('[generate_dataset]')
 
-    delet_path = join(source_dir, 'deletion_events.csv')
-    del_events = pd.read_csv(delet_path, usecols=['id', 'retracted'])
+    delet_path = join(source_dir, 'deletion_events.parquet')
+    del_events = pd.read_parquet(delet_path, columns=['id', 'retracted'])
     del_events = del_events.query('retracted == 0').copy()
     del_events.drop(columns=['retracted'], inplace=True)
     del_events['label'] = 'deleted'
 
-    ban_path = join(source_dir, 'ban_events.csv')
-    ban_events = pd.read_csv(ban_path, usecols=['authorChannelId', 'videoId'])
+    ban_path = join(source_dir, 'ban_events.parquet')
+    ban_events = pd.read_parquet(ban_path,
+                                 columns=['authorChannelId', 'videoId'])
     ban_events['label'] = 'hidden'
 
     for f in sorted(iglob(join(source_dir, matcher))):
@@ -24,17 +25,93 @@ def generate_dataset(source_dir, target_dir, matcher):
 
         # load chat
         print('>>> Loading chats')
-        chat_path = join(source_dir, 'chats_' + period_string + '.csv')
+        chat_path = join(source_dir, 'chats_' + period_string + '.parquet')
 
-        chats = pd.read_csv(chat_path,
-                            na_values='',
-                            keep_default_na=False,
-                            usecols=[
-                                'authorChannelId',
-                                'videoId',
-                                'id',
-                                'body',
-                            ])
+        chats = pd.read_parquet(
+            chat_path,
+            columns=['authorChannelId', 'videoId', 'id', 'authorName', 'body'])
+
+        # remove NA
+        chats = chats[chats['body'].notna()]
+
+        # apply mods
+        print('>>> Merging bans')
+        chats = pd.merge(chats,
+                         ban_events,
+                         on=['authorChannelId', 'videoId'],
+                         how='left')
+
+        # apply mods
+        print('>>> Merging deletion')
+        chats.loc[chats['id'].isin(del_events['id']), 'label'] = 'deleted'
+
+        # apply safe
+        print('>>> Applying safe')
+        chats['label'].fillna('nonflagged', inplace=True)
+
+        isFlagged = chats['label'] != 'nonflagged'
+        flagged = chats[isFlagged].copy()
+
+        # to make balanced dataset
+        nbFlagged = flagged.shape[0]
+        if nbFlagged == 0:
+            continue
+
+        print('>>> Sampling nonflagged chats')
+        print('nbFlagged', nbFlagged)
+        nonflag = chats[~isFlagged].sample(nbFlagged)
+
+        print('>>> Writing dataset')
+
+        # NOTE: do not use categorical type with to_parquest. otherwise, it will be failed to load them with huggingface's Dataset
+        columns_to_delete = [
+            'authorChannelId',
+            'videoId',
+            'id',
+        ]
+
+        flagged.drop(columns=columns_to_delete, inplace=True)
+        flagged.to_parquet(join(target_dir,
+                                f'chats_flagged_{period_string}.parquet'),
+                           index=False)
+
+        nonflag.drop(columns=columns_to_delete, inplace=True)
+        nonflag.to_parquet(join(target_dir,
+                                f'chats_nonflag_{period_string}.parquet'),
+                           index=False)
+
+        # free up memory
+        del nonflag
+        del flagged
+        del chats
+        gc.collect()
+
+
+def generate_reduced_dataset(source_dir, target_dir, matcher):
+    print('[generate_reduced_dataset]')
+
+    delet_path = join(source_dir, 'deletion_events.parquet')
+    del_events = pd.read_parquet(delet_path, columns=['id', 'retracted'])
+    del_events = del_events.query('retracted == 0').copy()
+    del_events.drop(columns=['retracted'], inplace=True)
+    del_events['label'] = 'deleted'
+
+    ban_path = join(source_dir, 'ban_events.parquet')
+    ban_events = pd.read_parquet(ban_path,
+                                 columns=['authorChannelId', 'videoId'])
+    ban_events['label'] = 'hidden'
+
+    for f in sorted(iglob(join(source_dir, matcher))):
+        period_string = splitext(basename(f))[0].split('_')[1]
+        print('>>> Period:', period_string)
+
+        # load chat
+        print('>>> Loading chats')
+        chat_path = join(source_dir, 'chats_' + period_string + '.parquet')
+
+        chats = pd.read_parquet(
+            chat_path,
+            columns=['authorChannelId', 'videoId', 'id', 'authorName', 'body'])
+
         # remove NA
         chats = chats[chats['body'].notna()]
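The labelling logic above is easiest to see on a toy frame. The following illustration uses synthetic data (not from the dataset) to show the merge-then-fill approach that `generate_dataset` now uses:

```python
import pandas as pd

# Synthetic stand-ins for the three inputs the generator works with.
chats = pd.DataFrame({
    'id': ['c1', 'c2', 'c3', 'c4'],
    'authorChannelId': ['a', 'b', 'c', 'a'],
    'videoId': ['v1', 'v1', 'v2', 'v2'],
    'body': ['hello', 'spam', 'nice', 'again'],
})
ban_events = pd.DataFrame({'authorChannelId': ['b'], 'videoId': ['v1'], 'label': ['hidden']})
del_events = pd.DataFrame({'id': ['c3'], 'label': ['deleted']})

# Bans attach by (author, video); deletions attach by chat id; the rest is 'nonflagged'.
chats = pd.merge(chats, ban_events, on=['authorChannelId', 'videoId'], how='left')
chats.loc[chats['id'].isin(del_events['id']), 'label'] = 'deleted'
chats['label'] = chats['label'].fillna('nonflagged')
print(chats[['body', 'label']])
#     body       label
# 0  hello  nonflagged
# 1   spam      hidden
# 2   nice     deleted
# 3  again  nonflagged
```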
sensai_dataset/generator/constants.py

@@ -1,7 +1,5 @@
 import os
 
-DATASET_DIR = os.environ['DATASET_DIR']
+SENSAI_DIR = os.environ['SENSAI_DIR']
+SENSAI_COMPLETE_DIR = os.environ['SENSAI_COMPLETE_DIR']
 DATASET_SOURCE_DIR = os.environ['DATASET_SOURCE_DIR']
-
-os.makedirs(DATASET_DIR, exist_ok=True)
-os.makedirs(DATASET_SOURCE_DIR, exist_ok=True)
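Note that the removed `os.makedirs` calls mean the target directories are no longer created automatically. A one-line sketch for callers that want the old behaviour back (directory paths are whatever the environment variables point at):

```python
import os

# The constants module now only reads these; create them up front if needed.
for var in ("SENSAI_DIR", "SENSAI_COMPLETE_DIR", "DATASET_SOURCE_DIR"):
    os.makedirs(os.environ[var], exist_ok=True)
```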