Skip to content

Commit

Permalink
outliers & evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
zk committed Aug 14, 2024
1 parent 8632983 commit d6ba25f
Show file tree
Hide file tree
Showing 4 changed files with 471 additions and 133 deletions.
310 changes: 294 additions & 16 deletions data_centric_evaluation/Lab - Data-Centric Evaluation.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

142 changes: 79 additions & 63 deletions dataset_curation/Lab - Dataset Curation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: cleanlab in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (2.2.0)\n",
"Requirement already satisfied: tqdm>=4.53.0 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from cleanlab) (4.64.1)\n",
"Requirement already satisfied: termcolor>=1.1.0 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from cleanlab) (2.2.0)\n",
"Requirement already satisfied: numpy>=1.11.3 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from cleanlab) (1.24.1)\n",
"Requirement already satisfied: scikit-learn>=0.18 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from cleanlab) (1.2.0)\n",
"Requirement already satisfied: pandas>=1.0.0 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from cleanlab) (1.5.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from pandas>=1.0.0->cleanlab) (2022.7)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from pandas>=1.0.0->cleanlab) (2.8.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from scikit-learn>=0.18->cleanlab) (3.1.0)\n",
"Requirement already satisfied: joblib>=1.1.1 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from scikit-learn>=0.18->cleanlab) (1.2.0)\n",
"Requirement already satisfied: scipy>=1.3.2 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from scikit-learn>=0.18->cleanlab) (1.10.0)\n",
"Requirement already satisfied: six>=1.5 in /Users/jonas/virtual/dcaiclass/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas>=1.0.0->cleanlab) (1.16.0)\n"
"Requirement already satisfied: cleanlab in /root/anaconda3/lib/python3.11/site-packages (2.6.6)\n",
"Requirement already satisfied: numpy~=1.22 in /root/anaconda3/lib/python3.11/site-packages (from cleanlab) (1.26.4)\n",
"Requirement already satisfied: scikit-learn>=1.1 in /root/anaconda3/lib/python3.11/site-packages (from cleanlab) (1.2.2)\n",
"Requirement already satisfied: tqdm>=4.53.0 in /root/anaconda3/lib/python3.11/site-packages (from cleanlab) (4.66.4)\n",
"Requirement already satisfied: pandas>=1.4.0 in /root/anaconda3/lib/python3.11/site-packages (from cleanlab) (2.1.4)\n",
"Requirement already satisfied: termcolor>=2.4.0 in /root/anaconda3/lib/python3.11/site-packages (from cleanlab) (2.4.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /root/anaconda3/lib/python3.11/site-packages (from pandas>=1.4.0->cleanlab) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /root/anaconda3/lib/python3.11/site-packages (from pandas>=1.4.0->cleanlab) (2023.3.post1)\n",
"Requirement already satisfied: tzdata>=2022.1 in /root/anaconda3/lib/python3.11/site-packages (from pandas>=1.4.0->cleanlab) (2023.3)\n",
"Requirement already satisfied: scipy>=1.3.2 in /root/anaconda3/lib/python3.11/site-packages (from scikit-learn>=1.1->cleanlab) (1.11.4)\n",
"Requirement already satisfied: joblib>=1.1.1 in /root/anaconda3/lib/python3.11/site-packages (from scikit-learn>=1.1->cleanlab) (1.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /root/anaconda3/lib/python3.11/site-packages (from scikit-learn>=1.1->cleanlab) (2.2.0)\n",
"Requirement already satisfied: six>=1.5 in /root/anaconda3/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas>=1.4.0->cleanlab) (1.16.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
Expand Down Expand Up @@ -143,10 +146,21 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"id": "9c9e83da",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['X_train', 'true_labels_train', 'multiannotator_labels'])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_dict = make_data(sample_size = 300)\n",
"\n",
Expand All @@ -165,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "aa80889d",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -215,9 +229,9 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>247</th>\n",
" <th>49</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>2</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
Expand All @@ -239,14 +253,14 @@
" <td>&lt;NA&gt;</td>\n",
" </tr>\n",
" <tr>\n",
" <th>290</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <th>188</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>1</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
Expand All @@ -257,49 +271,49 @@
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>1</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" </tr>\n",
" <tr>\n",
" <th>262</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <th>50</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>2</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" </tr>\n",
" <tr>\n",
" <th>182</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <th>55</th>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>...</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
Expand All @@ -311,14 +325,14 @@
" <td>&lt;NA&gt;</td>\n",
" </tr>\n",
" <tr>\n",
" <th>143</th>\n",
" <th>193</th>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>0</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>1</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
" <td>&lt;NA&gt;</td>\n",
Expand All @@ -341,23 +355,23 @@
],
"text/plain": [
" A0001 A0002 A0003 A0004 A0005 A0006 A0007 A0008 A0009 A0010 \\\n",
"247 <NA> 2 <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"290 <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"262 <NA> <NA> <NA> <NA> <NA> <NA> 2 <NA> <NA> <NA> \n",
"182 <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"143 <NA> <NA> 0 <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"49 <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"188 <NA> <NA> <NA> <NA> <NA> <NA> 1 <NA> <NA> <NA> \n",
"50 <NA> <NA> <NA> <NA> <NA> <NA> 0 <NA> <NA> 0 \n",
"55 0 <NA> <NA> <NA> <NA> <NA> <NA> 1 0 <NA> \n",
"193 <NA> <NA> <NA> <NA> <NA> <NA> 1 <NA> <NA> <NA> \n",
"\n",
" ... A0041 A0042 A0043 A0044 A0045 A0046 A0047 A0048 A0049 A0050 \n",
"247 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"290 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"262 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"182 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"143 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"49 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"188 ... <NA> <NA> <NA> <NA> <NA> <NA> 1 <NA> <NA> <NA> \n",
"50 ... <NA> <NA> <NA> <NA> 0 <NA> <NA> <NA> <NA> <NA> \n",
"55 ... 0 <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"193 ... <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> <NA> \n",
"\n",
"[5 rows x 50 columns]"
]
},
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -376,21 +390,21 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 11,
"id": "f5a59e23",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.01592896, 10.62213634],\n",
" [-1.91393643, 6.53944268],\n",
" [ 0.55962291, 5.35885902],\n",
" [ 6.73677377, 5.02311322],\n",
" [ 6.95949986, 1.61434817]])"
"array([[4.39299977, 0.98071378],\n",
" [6.35764575, 5.18249508],\n",
" [3.05336749, 3.19257978],\n",
" [4.83761843, 1.59196404],\n",
" [7.08868044, 7.70641701]])"
]
},
"execution_count": 6,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -411,7 +425,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 12,
"id": "77527d5c",
"metadata": {},
"outputs": [],
Expand All @@ -430,7 +444,7 @@
" model = KNeighborsClassifier(weights=\"distance\")\n",
" pred_probs = cross_val_predict(\n",
" estimator=model, X=X, y=labels_to_fit, cv=num_crossval_folds, method=\"predict_proba\"\n",
" )\n",
" ) # 可以得到一个N * K的矩阵\n",
" class_predictions = np.argmax(pred_probs, axis=1)\n",
" held_out_accuracy = np.mean(class_predictions == true_labels)\n",
" print(f\"Accuracy of held-out model predictions against ground truth labels: {held_out_accuracy}\")\n",
Expand All @@ -447,27 +461,27 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 16,
"id": "0d6cde68",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of random annotators' labels against ground truth labels: 0.6822742474916388\n",
"Accuracy of held-out model predictions against ground truth labels: 0.8093645484949833\n"
"Accuracy of random annotators' labels against ground truth labels: 0.7003367003367004\n",
"Accuracy of held-out model predictions against ground truth labels: 0.8047138047138047\n"
]
}
],
"source": [
"labels_from_random_annotators = true_labels.copy()\n",
"for i in range(len(multiannotator_labels)):\n",
" annotations_for_example_i = multiannotator_labels.iloc[i][pd.notna(multiannotator_labels.iloc[i])]\n",
" annotations_for_example_i = multiannotator_labels.iloc[i][pd.notna(multiannotator_labels.iloc[i])] # 取出所有标注者对这个样本的标注\n",
" labels_from_random_annotators[i] = np.random.choice(annotations_for_example_i.values)\n",
"\n",
"print(f\"Accuracy of random annotators' labels against ground truth labels: {np.mean(labels_from_random_annotators == true_labels)}\")\n",
"pred_probs_from_model_fit_to_random_annotators = train_model(labels_to_fit = labels_from_random_annotators)\n"
"pred_probs_from_model_fit_to_random_annotators = train_model(labels_to_fit = labels_from_random_annotators)"
]
},
{
Expand All @@ -480,19 +494,20 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 18,
"id": "ee1a3d99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of held-out model predictions against ground truth labels: 0.9765886287625418\n"
"Accuracy of held-out model predictions against ground truth labels: 0.9461279461279462\n"
]
}
],
"source": [
"# 果然用正确数据就是准\n",
"pred_probs_from_unrealistic_model_fit_to_true_labels = train_model(labels_to_fit = true_labels)"
]
},
Expand All @@ -512,12 +527,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 30,
"id": "4ee57297",
"metadata": {},
"outputs": [],
"source": [
"## Code your solution here"
"## Code your solution here\n",
"# 选出每一个样本出现最多次的,然后开始训练\n"
]
},
{
Expand Down Expand Up @@ -548,9 +564,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "dcaiclass",
"display_name": "Python 3",
"language": "python",
"name": "dcaiclass"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -562,7 +578,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit d6ba25f

Please sign in to comment.