Skip to content

Commit

Permalink
Tutorial mnist: improve 1.3 Remove contradictory (#505)
Browse files Browse the repository at this point in the history
* Tutorial mnist: improve 1.3 Remove contradictory

* Tutorial mnist: improve 1.3 data report

Co-authored-by: MichaelBroughton <[email protected]>
  • Loading branch information
Bankde and MichaelBroughton authored Mar 19, 2021
1 parent 52ccd82 commit c49f0f1
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions docs/tutorials/mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -382,32 +382,35 @@
"source": [
"def remove_contradicting(xs, ys):\n",
" mapping = collections.defaultdict(set)\n",
" orig_x = {}\n",
" # Determine the set of labels for each unique image:\n",
" for x,y in zip(xs,ys):\n",
" orig_x[tuple(x.flatten())] = x\n",
" mapping[tuple(x.flatten())].add(y)\n",
" \n",
" new_x = []\n",
" new_y = []\n",
" for x,y in zip(xs, ys):\n",
" labels = mapping[tuple(x.flatten())]\n",
" for flatten_x in mapping:\n",
" x = orig_x[flatten_x]\n",
" labels = mapping[flatten_x]\n",
" if len(labels) == 1:\n",
" new_x.append(x)\n",
" new_y.append(labels.pop())\n",
" new_y.append(next(iter(labels)))\n",
" else:\n",
" # Throw out images that match more than one label.\n",
" pass\n",
" \n",
" num_3 = sum(1 for value in mapping.values() if True in value)\n",
" num_6 = sum(1 for value in mapping.values() if False in value)\n",
" num_both = sum(1 for value in mapping.values() if len(value) == 2)\n",
" num_uniq_3 = sum(1 for value in mapping.values() if len(value) == 1 and True in value)\n",
" num_uniq_6 = sum(1 for value in mapping.values() if len(value) == 1 and False in value)\n",
" num_uniq_both = sum(1 for value in mapping.values() if len(value) == 2)\n",
"\n",
" print(\"Number of unique images:\", len(mapping.values()))\n",
" print(\"Number of 3s: \", num_3)\n",
" print(\"Number of 6s: \", num_6)\n",
" print(\"Number of contradictory images: \", num_both)\n",
" print(\"Number of unique 3s: \", num_uniq_3)\n",
" print(\"Number of unique 6s: \", num_uniq_6)\n",
" print(\"Number of unique contradicting labels (both 3 and 6): \", num_uniq_both)\n",
" print()\n",
" print(\"Initial number of examples: \", len(xs))\n",
" print(\"Remaining non-contradictory examples: \", len(new_x))\n",
" print(\"Initial number of images: \", len(xs))\n",
" print(\"Remaining non-contradicting unique images: \", len(new_x))\n",
" \n",
" return np.array(new_x), np.array(new_y)"
]
Expand Down

0 comments on commit c49f0f1

Please sign in to comment.