From 70042a552fc44a7a102fe582fbeeaddfefb8407f Mon Sep 17 00:00:00 2001 From: James Briggs Date: Thu, 29 Apr 2021 20:23:00 +0100 Subject: [PATCH] added first draft notebooks for similarity section --- .../03_bidirectional_attention.ipynb | 9 +- .../introduction/02_transformer_heads.ipynb | 7 + course/similarity/00_intro.ipynb | 32 + .../02_transformer_dense_vectors.ipynb | 679 ++++++++++++++++++ .../03_similarity_metrics.ipynb} | 0 .../04_calculating_similarity.ipynb | 105 +++ .../similarity/05_sentence_transformers.ipynb | 41 ++ 7 files changed, 872 insertions(+), 1 deletion(-) create mode 100644 course/similarity/00_intro.ipynb create mode 100644 course/similarity/02_transformer_dense_vectors.ipynb rename course/{similarity_search/similarity_metrics.ipynb => similarity/03_similarity_metrics.ipynb} (100%) create mode 100644 course/similarity/04_calculating_similarity.ipynb create mode 100644 course/similarity/05_sentence_transformers.ipynb diff --git a/course/attention/03_bidirectional_attention.ipynb b/course/attention/03_bidirectional_attention.ipynb index 889eae4..a9dbaa0 100644 --- a/course/attention/03_bidirectional_attention.ipynb +++ b/course/attention/03_bidirectional_attention.ipynb @@ -6,10 +6,17 @@ "source": [ "# Bi-directional Attention\n", "\n", - "We've explored both dot-product attention, and self-attention. Where dot-product compared two sequences, and causal attention compared previous tokens from the *same sequence*, bidirectional attention compares tokens from the *same sequence* in both directions, subsequent and previous. This is as simple as performing the exact same operation that we performed for *self-attention*, but excluding the masking operation - allowing each word to be mapped to every other word in the same sequence. So, we could call this *bi-directional **self** attention*. This is particularly useful for masked language modeling - and is used in BERT (**Bidirectional Encoder** Representations from Transformers) - bidirectional self-attention refers to the *bidirectional encoder*, or the *BE* of BERT.\n", + "We've explored both dot-product attention, and self-attention. Where dot-product compared two sequences, and self attention compared previous tokens from the *same sequence*, bidirectional attention compares tokens from the *same sequence* in both directions, subsequent and previous. This is as simple as performing the exact same operation that we performed for *self-attention*, but excluding the masking operation - allowing each word to be mapped to every other word in the same sequence. So, we could call this *bi-directional **self** attention*. 
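The difference is easiest to see in a minimal sketch (toy tensors only, made-up sizes, no learned projection weights — not the course's implementation): plain scaled dot-product self-attention is computed over the full sequence, and the single masking line that a causal model would add is simply left out.

```python
import torch
import torch.nn.functional as F

# toy example: 4 token embeddings of dimension 8 (made-up sizes, random values)
seq_len, d = 4, 8
x = torch.randn(seq_len, d)

# in a real transformer Q, K and V come from learned projections of x;
# we reuse x directly here to keep the sketch minimal
q, k, v = x, x, x

scores = q @ k.T / d ** 0.5            # every token scores against every other token
# a causal model would mask future positions here, e.g.
# scores = scores.masked_fill(torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool(), float('-inf'))
weights = F.softmax(scores, dim=-1)    # bidirectional: no positions are masked out
context = weights @ v                  # each output mixes information from all positions
```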
This is particularly useful for masked language modeling - and is used in BERT (**Bidirectional Encoder** Representations from Transformers) - bidirectional self-attention refers to the *bidirectional encoder*, or the *BE* of BERT.\n", "\n", "![Bidirectional Attention](../../assets/images/bidirectional_attention.png)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/course/introduction/02_transformer_heads.ipynb b/course/introduction/02_transformer_heads.ipynb index 8616e94..5fd947d 100644 --- a/course/introduction/02_transformer_heads.ipynb +++ b/course/introduction/02_transformer_heads.ipynb @@ -38,6 +38,13 @@ "\n", "![Q&A BERT showing additional start/end token classifiers](../../assets/images/qa_linear_bert.png)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/course/similarity/00_intro.ipynb b/course/similarity/00_intro.ipynb new file mode 100644 index 0000000..01ae31b --- /dev/null +++ b/course/similarity/00_intro.ipynb @@ -0,0 +1,32 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ML", + "language": "python", + "name": "ml" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/course/similarity/02_transformer_dense_vectors.ipynb b/course/similarity/02_transformer_dense_vectors.ipynb new file mode 100644 index 0000000..9a507d2 --- /dev/null +++ b/course/similarity/02_transformer_dense_vectors.ipynb @@ -0,0 +1,679 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building Dense Vectors Using Transformers\n", + "\n", + "We will be using the [`sentence-transformers/stsb-distilbert-base`](https://huggingface.co/sentence-transformers/stsb-distilbert-base) model to build our dense vectors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModel\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First we initialize our model and tokenizer:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/stsb-distilbert-base')\n", + "model = AutoModel.from_pretrained('sentence-transformers/stsb-distilbert-base')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we tokenize a sentence just as we have been doing before:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "text = \"hello world what a time to be alive!\"\n", + "\n", + "tokens = tokenizer.encode_plus(text, max_length=128,\n", + " truncation=True, padding='max_length',\n", + " return_tensors='pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We process these tokens through our model:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BaseModelOutput(last_hidden_state=tensor([[[-0.9489, 0.6905, -0.2188, ..., 0.0161, 0.5874, -0.1449],\n", + " [-0.6643, 1.1984, -0.1346, ..., 0.4839, 0.6338, -0.5003],\n", + " [-0.3289, 0.6412, 0.2473, ..., -0.0965, 0.4298, 0.0515],\n", + " ...,\n", + " [-0.7853, 0.8094, -0.2639, ..., 0.2177, 0.3335, 0.1107],\n", + " [-0.7528, 0.6285, -0.0088, ..., 0.1024, 0.4585, 0.1720],\n", + " [-1.0754, 0.4878, -0.3458, ..., 0.2764, 0.5604, 0.1236]]],\n", + " grad_fn=), hidden_states=None, attentions=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs = model(**tokens)\n", + "outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dense vector representations of our `text` are contained within the `outputs` **'last_hidden_state'** tensor, which we access like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.9489, 0.6905, -0.2188, ..., 0.0161, 0.5874, -0.1449],\n", + " [-0.6643, 1.1984, -0.1346, ..., 0.4839, 0.6338, -0.5003],\n", + " [-0.3289, 0.6412, 0.2473, ..., -0.0965, 0.4298, 0.0515],\n", + " ...,\n", + " [-0.7853, 0.8094, -0.2639, ..., 0.2177, 0.3335, 0.1107],\n", + " [-0.7528, 0.6285, -0.0088, ..., 0.1024, 0.4585, 0.1720],\n", + " [-1.0754, 0.4878, -0.3458, ..., 0.2764, 0.5604, 0.1236]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings = outputs.last_hidden_state\n", + "embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128, 768])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After we have produced our dense vectors `embeddings`, we need to perform a *mean pooling* operation on them to create a single vector encoding. 
To do this mean pooling operation we will need to multiply each value in our `embeddings` tensor by it's respective `attention_mask` value - so that we ignore non-real tokens.\n", + "\n", + "To perform this operation, we first resize our `attention_mask` tensor:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_mask = tokens['attention_mask']\n", + "attention_mask.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128, 768])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()\n", + "mask.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_mask" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([768])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask[0][0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we multiply the two tensors to apply the attention mask:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 128, 768])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "masked_embeddings = embeddings * mask\n", + "masked_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.9489, 0.6905, -0.2188, ..., 0.0161, 0.5874, -0.1449],\n", + " [-0.6643, 1.1984, -0.1346, ..., 0.4839, 0.6338, -0.5003],\n", + " [-0.3289, 0.6412, 0.2473, ..., -0.0965, 0.4298, 0.0515],\n", + " ...,\n", + " [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } 
+ ], + "source": [ + "masked_embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we sum the remained of the embeddings along axis `1`:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 768])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summed = torch.sum(masked_embeddings, 1)\n", + "summed.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then sum the number of values that must be given attention in each position of the tensor:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 768])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summed_mask = torch.clamp(mask.sum(1), min=1e-9)\n", + "summed_mask.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 
11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,\n", + " 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summed_mask" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we calculate the mean as the sum of the embedding activations `summed` divided by the number of values that should be given attention in each position `summed_mask`:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "mean_pooled = summed / summed_mask" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-3.8485e-01, 7.8107e-01, -1.7720e-01, -1.4125e+00, -2.3358e-01,\n", + " 9.0891e-01, -7.8390e-02, 6.0347e-01, 6.7886e-02, -3.9842e-01,\n", + " 3.9223e-02, -4.6774e-01, -7.1848e-01, -1.1863e-01, -7.1194e-02,\n", + " 6.6017e-03, -1.4093e-01, 3.1271e-01, -6.5574e-01, -1.6470e-01,\n", + " -1.0026e-01, -3.8357e-01, 6.1278e-02, -7.3818e-01, -5.9918e-01,\n", + " 2.8855e-01, 8.6372e-01, 5.8388e-01, -3.5059e-02, 4.3197e-01,\n", + " -5.0111e-01, -4.3498e-01, 2.3498e-01, -3.7127e-01, -1.0044e+00,\n", + " 1.0000e+00, -2.1000e+00, -3.2251e-01, -1.6085e-01, -7.3701e-01,\n", + " 5.4928e-01, -1.2066e-01, 7.2698e-01, -5.0327e-02, -1.7545e+00,\n", + " 8.0573e-01, -5.0553e-01, -4.7172e-01, -1.6727e-01, 5.9727e-01,\n", + " 5.6203e-01, -3.6104e-01, -1.6429e-01, -5.5215e-01, -5.0417e-01,\n", + " 5.6187e-01, -1.1415e+00, 1.0771e+00, 5.5689e-01, -7.0632e-02,\n", + " -2.6932e-01, -6.8905e-01, 1.8093e-01, 3.1045e-01, 3.9036e-02,\n", + " 3.1064e-01, -4.4495e-01, 
-4.7363e-02, 1.7010e+00, 1.2346e-01,\n", + " 6.7439e-01, -6.4171e-01, 4.6368e-01, -1.0917e+00, -7.6345e-01,\n", + " -3.0914e-01, -3.0577e-01, 7.6974e-01, -1.3508e+00, -2.3315e-01,\n", + " -3.3893e-01, -7.1898e-01, -1.4306e-01, 5.6030e-01, -6.8665e-01,\n", + " -6.2314e-01, -1.6671e-01, 4.8549e-01, -4.1270e-01, 1.0744e+00,\n", + " 7.6589e-01, 3.2901e-01, 3.7139e-01, 4.4882e-01, 9.8538e-01,\n", + " 2.9093e-01, -8.5645e-01, -7.6587e-01, 3.4619e-01, 1.7652e-01,\n", + " 7.8148e-02, -2.3852e-01, 7.3973e-01, -3.3606e-01, -1.3271e-01,\n", + " -3.0365e-02, 9.4191e-01, -9.6431e-02, -4.2533e-01, -3.6001e-01,\n", + " 4.0952e-01, -3.5570e-01, -6.1359e-01, 1.0244e+00, -1.6946e-01,\n", + " 5.8016e-01, -3.5626e-01, -1.8950e-01, -1.3671e+00, -4.2409e-02,\n", + " 7.6267e-01, -3.8510e-01, -4.4530e-01, 6.0832e-01, -3.5163e-01,\n", + " -3.1655e-01, 5.4221e-01, -4.2778e-02, 1.8298e+00, -5.0055e-02,\n", + " 4.0329e-01, 3.8894e-01, -1.4479e-01, -5.0510e-01, 4.9653e-01,\n", + " 2.2225e-01, -1.7006e-01, 1.6015e-01, -5.6461e-01, -6.1269e-01,\n", + " 6.5336e-01, 7.9370e-01, -1.3359e-01, -2.9643e-01, 6.5500e-01,\n", + " 4.5574e-01, 1.3022e-01, 8.5983e-01, 9.5493e-02, -4.1140e-01,\n", + " -6.0792e-01, -6.8995e-01, -8.0654e-01, -3.3334e-01, -6.1655e-01,\n", + " -6.7164e-01, -7.9168e-01, -6.0887e-01, -6.7842e-01, 7.7270e-01,\n", + " -1.8380e-01, 4.2944e-01, 8.2487e-01, -3.1597e-01, -1.6402e+00,\n", + " 8.4033e-01, 6.2706e-01, 4.5629e-01, 3.6232e-01, -4.5201e-01,\n", + " -1.8216e-02, 6.0873e-02, -7.5308e-01, 1.0445e+00, -1.2082e+00,\n", + " -3.3916e-01, 5.6531e-01, -8.5100e-01, -3.4787e-01, 1.2856e-01,\n", + " -5.8088e-01, -5.2811e-01, -1.2420e-02, 6.3882e-02, -6.7478e-02,\n", + " -1.7361e-01, 7.0940e-01, -1.4097e+00, -1.3953e-01, -1.9912e-01,\n", + " 4.7711e-01, 8.0363e-01, 3.7343e-01, -5.2863e-01, 3.9637e-01,\n", + " 9.4878e-01, 1.2467e-01, -1.7886e-01, 7.9958e-01, 3.6676e-02,\n", + " 3.2056e-01, 1.0148e+00, 6.4841e-01, -1.6269e-01, 3.3561e-01,\n", + " -5.0149e-01, 2.4113e-01, -8.9913e-01, -1.2511e-01, -3.3696e-01,\n", + " 3.0698e-01, -2.2303e-01, -3.8996e-01, -2.8901e-01, -3.8460e-01,\n", + " 8.8542e-01, -1.6964e-01, 2.4618e-01, -8.4116e-01, -3.2811e-01,\n", + " -3.2727e-01, -2.2230e-02, 4.5131e-01, -4.1267e-01, 9.1945e-01,\n", + " -6.3390e-01, 9.4229e-01, 1.3876e-01, -2.5902e-02, -4.5191e-01,\n", + " 5.7052e-02, 2.4740e-01, -5.7986e-01, 7.7694e-02, 7.6410e-02,\n", + " -3.4006e-01, 4.3327e-01, -3.9236e-01, 4.5135e-01, -5.3925e-01,\n", + " -5.7638e-01, -5.1190e-01, -2.4838e-02, -2.9940e-01, -2.6119e-01,\n", + " -6.8238e-01, -1.0826e-01, 2.7870e-01, 5.2347e-01, -1.2790e+00,\n", + " -7.5903e-01, -4.1540e-01, 2.7823e-01, 1.2852e-01, 7.8037e-01,\n", + " -8.0996e-01, -7.2413e-01, -8.6791e-01, -8.5757e-01, 1.7594e-01,\n", + " 4.0083e-01, -7.3397e-01, 3.9002e-01, 1.1243e-01, 1.2089e-01,\n", + " -1.0793e-01, 1.1323e-01, -2.7789e-01, -7.9843e-03, 2.1983e-02,\n", + " -8.3703e-01, 3.4330e-01, 1.3543e-01, 5.7925e-02, 1.9617e+00,\n", + " 8.3412e-01, -7.2922e-01, 4.6160e-01, -3.0357e-01, -3.4166e-01,\n", + " -6.9856e-01, 9.4905e-02, -1.1436e+00, 4.2823e-02, -3.4007e-01,\n", + " 1.7994e-01, -6.0481e-01, -9.6435e-01, -4.9273e-01, -1.3865e-01,\n", + " -1.5011e+00, 3.8161e-01, -3.7843e-01, -2.7537e-01, 8.0850e-02,\n", + " 3.0396e-01, 6.3909e-03, -1.0301e+00, -7.4074e-01, 6.2466e-02,\n", + " 2.3288e-01, -9.3533e-01, 1.0078e-01, -1.2119e+00, -6.6829e-01,\n", + " -5.5835e-02, -9.2504e-02, -4.9092e-01, 1.1554e-01, 4.5436e-01,\n", + " 1.2235e+00, 5.7517e-01, -1.0007e+00, 6.2173e-01, -6.6875e-02,\n", + " 1.2660e+00, 4.9329e-01, -1.5459e-01, 
9.2993e-02, -5.6998e-02,\n", + " -2.8559e-02, 5.2734e-01, -3.0664e-01, -8.0383e-01, -1.6655e-01,\n", + " -5.5859e-01, -5.1713e-01, 4.6910e-02, -1.1429e+00, -1.6247e-01,\n", + " -1.2338e-01, -3.4689e-01, -3.5227e-01, -3.8736e-01, 1.0393e+00,\n", + " -1.6471e-01, -1.3883e-01, 9.1784e-01, -7.2758e-01, -1.5185e-01,\n", + " 2.8702e-01, -2.0967e-01, 5.5545e-01, 1.8944e-01, -5.0340e-01,\n", + " -1.0897e+00, -7.5433e-01, -1.3625e+00, -3.8772e-02, -7.7702e-01,\n", + " 2.5420e-01, 3.1660e-01, -8.6211e-01, -2.3615e-01, 4.6479e-03,\n", + " 5.1776e-01, 4.7276e-01, -7.3881e-02, -5.5788e-01, 2.4043e-01,\n", + " 9.5054e-01, 2.6625e-02, -4.6604e-01, -3.3385e-01, 4.4900e-01,\n", + " 1.1014e+00, 5.9351e-01, -1.2061e-01, 2.1053e-01, 7.3098e-01,\n", + " -2.3732e-02, 2.6349e-01, 3.3863e-01, 5.9553e-01, -3.3448e-01,\n", + " 1.2544e-01, 3.3026e-01, -1.5698e-01, -3.7932e-01, -2.5078e-01,\n", + " -2.9495e-01, 8.2592e-02, 1.8376e-01, -6.8231e-01, -8.6327e-02,\n", + " -5.7801e-01, -8.4704e-02, -1.4150e-01, 9.1605e-01, -5.6759e-01,\n", + " -1.0993e-01, -1.5896e-02, 6.2933e-01, 2.1628e-01, 1.1261e-01,\n", + " 6.5828e-01, 4.5636e-01, 1.0936e+00, 7.4275e-01, -3.7315e-01,\n", + " 3.7326e-01, 1.0809e+00, -2.4348e-01, -4.9122e-01, 1.1691e+00,\n", + " 1.0116e+00, -2.2179e-01, -8.4004e-02, 4.4811e-01, 8.3704e-01,\n", + " -1.4922e-01, -7.3480e-02, 2.8369e-01, 5.3243e-01, 3.5504e-02,\n", + " -7.2948e-01, 2.2285e-01, -8.1695e-01, -9.8309e-02, 1.6787e-02,\n", + " -1.0060e+00, 5.8846e-02, -1.1733e-01, -2.5029e-03, 9.7850e-01,\n", + " 4.2993e-01, 5.5168e-01, 7.5765e-01, 4.1643e-01, -7.7879e-01,\n", + " 6.5853e-01, -2.7104e-01, -2.1195e-01, 2.6836e-01, 8.9252e-02,\n", + " -2.2026e-01, 7.0055e-01, -5.0542e-01, -9.2811e-01, 2.8497e-01,\n", + " -3.2909e-01, 9.0162e-01, 5.6190e-01, 3.8479e-02, 7.6101e-01,\n", + " -2.4245e-02, -5.2505e-01, 4.9243e-01, -1.1323e+00, -7.9398e-02,\n", + " 8.9294e-01, -4.1039e-01, -4.2587e-01, 5.6288e-01, -3.1121e-01,\n", + " -5.0377e-02, -7.7956e-01, 7.0310e-01, -1.1243e-01, 3.1637e-01,\n", + " 1.5981e-01, -9.6209e-03, -1.0382e+00, -1.8747e-03, -9.9495e-02,\n", + " -1.5131e+00, 7.5718e-01, -7.7793e-02, 1.0319e+00, 5.2133e-01,\n", + " -3.2082e-01, -1.3737e-01, 1.0844e+00, -3.7648e-01, -3.4650e-02,\n", + " 1.3097e-01, -3.5184e-01, -8.1428e-01, 3.4189e-01, 6.7281e-02,\n", + " -1.9175e-01, 2.2250e-01, 4.0790e-01, -4.0171e-02, 9.3394e-01,\n", + " 1.4848e-01, 9.4151e-02, -6.1521e-01, 1.8073e-01, -9.3871e-01,\n", + " -3.5805e-01, -1.1437e-01, 9.8406e-01, -1.3756e+00, 9.7456e-02,\n", + " -2.3249e-01, 9.5018e-01, -2.1394e-01, -2.3394e-01, 5.0237e-01,\n", + " 2.5898e-01, -2.2813e-01, -2.3680e-01, -2.3152e-01, -9.8057e-01,\n", + " 3.9108e-01, 5.8315e-01, 1.6551e-01, -3.6449e-01, 4.2075e-01,\n", + " 9.3581e-01, -4.6776e-01, -2.2665e-02, 3.7928e-01, -6.1125e-01,\n", + " -4.0730e-01, 9.0755e-01, 1.0523e+00, -1.9673e-01, -6.0428e-02,\n", + " -5.5663e-02, 3.8640e-01, -1.2758e-01, 6.1613e-01, -3.9228e-01,\n", + " 9.0591e-01, -4.3536e-01, -7.4162e-02, 1.0847e-01, -6.4019e-02,\n", + " 6.2278e-01, 3.7997e-01, -1.1579e-01, -1.9312e+00, 7.1141e-01,\n", + " 1.1751e-01, -4.1557e-01, -7.7247e-01, 6.3692e-01, 5.3097e-01,\n", + " 9.7168e-02, -6.8854e-02, -8.8752e-01, 4.2003e-01, 1.4736e-01,\n", + " 4.4949e-01, 1.0757e-01, 8.5666e-01, 2.1895e-01, -1.4616e-01,\n", + " -2.1148e-01, 7.3091e-02, 5.6748e-01, 3.9416e-01, 2.8383e-02,\n", + " 1.0420e+00, -9.9249e-02, -5.5125e-01, 7.3612e-02, 1.1771e+00,\n", + " -5.5362e-01, -1.0581e-01, -4.2232e-01, -1.5856e+00, 7.3779e-01,\n", + " -1.4219e-01, -1.0619e+00, -6.8308e-01, 1.3319e-02, 
4.1730e-01,\n", + " -1.1350e+00, 2.5110e-01, 4.9541e-01, 1.0239e-01, -7.1889e-01,\n", + " 1.0615e-01, 7.6836e-01, -6.0918e-02, -3.6846e-01, 5.1103e-02,\n", + " -6.9368e-01, -1.6377e-01, 7.2992e-01, -2.7181e-01, -1.7474e-01,\n", + " 6.6675e-01, -2.4677e-01, -2.8554e-01, -7.7832e-02, 8.7495e-02,\n", + " 2.1369e-01, 8.7279e-01, -9.8810e-02, -5.0639e-01, -4.2866e-01,\n", + " -5.1867e-01, 4.2720e-01, 3.1696e-01, -2.9805e-01, -8.3426e-01,\n", + " -1.0784e+00, -7.7276e-01, 4.9140e-01, 1.1272e+00, 1.3698e-02,\n", + " -6.8530e-02, -1.1509e-01, -6.5638e-01, 7.9699e-01, -2.6068e-01,\n", + " 1.0395e+00, -4.7972e-01, -1.4439e-01, 7.8087e-01, -5.9054e-01,\n", + " 2.1602e-01, -7.4449e-01, -1.3328e-01, -1.4614e-01, 8.9816e-01,\n", + " -1.0125e+00, -6.5561e-01, 7.6670e-01, -2.8419e-01, -1.3880e-01,\n", + " -7.1945e-01, -5.1779e-01, 1.4314e-01, -2.3534e-01, -5.9846e-01,\n", + " 6.0434e-02, -6.3184e-02, -1.5664e+00, -1.9544e-01, 2.0409e-01,\n", + " -1.0337e+00, 9.1216e-01, 2.3952e-01, 1.0880e-01, -2.0045e-01,\n", + " 8.4616e-01, -7.5020e-02, 3.4787e-01, -1.5094e+00, -2.5039e-01,\n", + " -6.5037e-02, 6.9634e-01, -2.6770e-01, 8.9710e-02, -4.8853e-01,\n", + " 7.0874e-01, -7.6796e-01, 8.4987e-01, 4.1382e-01, -4.0460e-01,\n", + " 2.8681e-01, 1.0482e+00, 1.6342e-01, 8.9450e-02, -2.9139e-01,\n", + " -6.0596e-01, -1.0153e-01, -3.3035e-01, -4.3888e-01, -6.9056e-02,\n", + " 5.0943e-02, 3.7704e-01, 6.6890e-02, 5.8372e-01, 3.2396e-01,\n", + " -2.4983e-01, -2.9541e-01, 3.7929e-01, -3.1190e-01, -1.3260e-01,\n", + " 5.0000e-01, 6.4270e-01, -3.0923e-01, -2.9641e-01, -8.8432e-01,\n", + " 4.0610e-01, 6.5061e-01, -1.2177e-02, 1.1644e+00, -5.9447e-01,\n", + " -1.8063e-01, 6.1685e-01, -1.1272e-01, -6.0815e-01, -7.2103e-01,\n", + " 7.6628e-01, -3.2992e-01, -5.0285e-01, 5.1563e-01, -4.1571e-01,\n", + " -4.7703e-01, -2.8721e-01, 8.4478e-01, -3.9162e-01, 9.2312e-02,\n", + " 7.8206e-01, 2.1263e-01, -6.4607e-01, 9.6211e-01, 1.2251e-01,\n", + " -9.9896e-01, 1.4947e-01, -2.5206e-01, -3.9582e-01, 8.5691e-01,\n", + " 3.5398e-01, -8.7718e-01, 2.3607e-01, -2.1513e-01, -6.1426e-01,\n", + " -1.4891e-01, 4.2167e-01, -9.3733e-01, 7.1058e-01, -1.5226e-01,\n", + " 1.3321e+00, 1.2884e-01, 1.4089e-01, 6.7874e-01, -5.8004e-01,\n", + " 1.5258e-01, 5.5926e-01, -1.4397e-01, -3.4066e-01, -6.2638e-01,\n", + " 1.5209e-01, 1.1089e-01, 6.4280e-02, 5.5102e-01, -1.1813e-01,\n", + " -4.6605e-01, 3.0925e-01, -3.5231e-01, 4.1569e-01, 1.4494e+00,\n", + " -3.5271e-01, -9.2626e-01, -3.0457e-01, -2.8597e-01, -4.8825e-01,\n", + " -2.0081e-01, 6.2255e-02, -6.3340e-01, 4.3752e-02, 3.5119e-01,\n", + " 7.5687e-01, -3.1523e-01, 2.9945e-01, -7.5232e-01, 1.2727e-01,\n", + " -5.1653e-01, -1.7894e-02, 4.0687e-01, -1.9404e-01, 3.7166e-01,\n", + " 6.0207e-01, 9.7962e-01, 1.3890e-01, 6.2343e-01, 4.1314e-01,\n", + " 4.1372e-01, 5.6574e-01, -4.6809e-01]], grad_fn=)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mean_pooled" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And that is how we calculate our dense similarity vector." 
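The steps above can also be collected into a small helper for reuse in later sections — a sketch only, reusing the tensor names from this notebook (`mean_pool` is our own name, not a `transformers` function):

```python
def mean_pool(last_hidden_state, attention_mask):
    # our own helper, not part of the transformers library
    # expand the attention mask to the embedding dimension
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    # zero out padding positions, then average over the real tokens only
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

# e.g. mean_pooled = mean_pool(outputs.last_hidden_state, tokens['attention_mask'])
```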
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ML", + "language": "python", + "name": "ml" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/course/similarity_search/similarity_metrics.ipynb b/course/similarity/03_similarity_metrics.ipynb similarity index 100% rename from course/similarity_search/similarity_metrics.ipynb rename to course/similarity/03_similarity_metrics.ipynb diff --git a/course/similarity/04_calculating_similarity.ipynb b/course/similarity/04_calculating_similarity.ipynb new file mode 100644 index 0000000..6a70f81 --- /dev/null +++ b/course/similarity/04_calculating_similarity.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Calculating Similarity\n", + "\n", + "When calculating similarity between our transformer embedded vectors, we can use any of the *three* similarity metrics already covered.\n", + "\n", + "But first, let's create some embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "sentences = [\n", + " \"Three years later, the coffin was still full of Jello.\",\n", + " \"The fish dreamed of escaping the fishbowl and into the toilet where he saw his friend go.\",\n", + " \"The person box was packed with jelly many dozens of months later.\",\n", + " \"Standing on one's head at job interviews forms a lasting impression.\",\n", + " \"It took him a month to finish the meal.\",\n", + " \"It turns out you don't need all that stuff you insisted you did.\"\n", + " \"He found a leprechaun in his walnut shell.\"\n", + "]\n", + "\n", + "# thanks to https://randomwordgenerator.com/sentence.php" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModel\n", + "import torch\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/stsb-distilbert-base')\n", + "model = AutoModel.from_pretrained('sentence-transformers/stsb-distilbert-base')\n", + "\n", + "# tokenize sequences\n", + "tokens = []\n", + "\n", + "for sentence in sentences:\n", + " tokens.append(\n", + " tokenizer.encode_plus(sentence, max_length=128, truncation=True,\n", + " padding='max_length', return_tensors='pt')\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "expected Tensor as element 0 in argument 0, but got BatchEncoding", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mtokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m: expected Tensor as element 0 in argument 0, but got BatchEncoding" + ] + } + ], + "source": [ + "tokens = torch.stack(tokens)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ML", + "language": "python", + "name": "ml" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/course/similarity/05_sentence_transformers.ipynb b/course/similarity/05_sentence_transformers.ipynb new file mode 100644 index 0000000..06b65ef --- /dev/null +++ b/course/similarity/05_sentence_transformers.ipynb @@ -0,0 +1,41 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Embeddings With Sentence-Transformers\n", + "\n", + "We've worked through creating our embeddings using the `transformers` library - and at times it can be quite involved. Now, it's important to understand the steps, but we can make life easier by using the `sentence-transformers` library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ML", + "language": "python", + "name": "ml" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}