Skip to content

Commit

Permalink
cambios 18-5-24
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlosespicur committed May 18, 2024
1 parent 9d93972 commit b98a13a
Show file tree
Hide file tree
Showing 5 changed files with 3,559 additions and 1 deletion.
2 changes: 1 addition & 1 deletion GAE_fashion_MNIST_pruebas.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Pruebas_encoder_SLMVP.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Pruebas_encoder_Triplet_loss.ipynb

Large diffs are not rendered by default.

186 changes: 186 additions & 0 deletions centroides.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Carlos\\anaconda3\\envs\\pytorch\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def binary_distance(X, Y):\n",
" \"\"\"Compute distance matrix between rows of X, Y.\n",
"\n",
" d(x_i, y_j) = 1 if x_i == y_j, 0 in other case.\n",
"\n",
" for all rows x_i in X, y_j in Y\n",
"\n",
" \"\"\"\n",
" return (X.unsqueeze(1) == Y.unsqueeze(0)).all(-1).float()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def compute_centroids(X, y):\n",
" label_matrix = binary_distance(y.unique().unsqueeze(1), y.unsqueeze(1))\n",
" label_matrix = torch.div(label_matrix, torch.sum(label_matrix, dim=1).unsqueeze(1))\n",
" return torch.matmul(label_matrix, X), y.unique()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"prueba = torch.tensor([2,3,2,2,1,2,3,4], dtype=torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"etiquetas = prueba.unique()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"matriz = binary_distance(etiquetas.unsqueeze(1), prueba.unsqueeze(1))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 1., 0., 0., 0.],\n",
" [1., 0., 1., 1., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1.]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matriz"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"X = torch.rand((5,6), dtype=torch.float)\n",
"y = torch.tensor([2,3,2,2,1], dtype=torch.float)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.8297, 0.3224, 0.5171, 0.4091, 0.2864, 0.7082],\n",
" [0.8326, 0.8648, 0.8250, 0.7273, 0.6693, 0.8729],\n",
" [0.8528, 0.2973, 0.4668, 0.0865, 0.9033, 0.1066],\n",
" [0.0826, 0.5666, 0.1821, 0.4346, 0.9174, 0.1333],\n",
" [0.2146, 0.4583, 0.8526, 0.2099, 0.6207, 0.7287]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.2146, 0.4583, 0.8526, 0.2099, 0.6207, 0.7287],\n",
" [0.5884, 0.3954, 0.3887, 0.3101, 0.7024, 0.3160],\n",
" [0.8326, 0.8648, 0.8250, 0.7273, 0.6693, 0.8729]]),\n",
" tensor([1., 2., 3.]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compute_centroids(X,y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit b98a13a

Please sign in to comment.