Skip to content

Commit f9f1a16

Browse files
committed
Attention basics
1 parent 68cb7f4 commit f9f1a16

File tree

5 files changed

+728
-0
lines changed

5 files changed

+728
-0
lines changed

attention/Attention_Basics.ipynb

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Attention Basics\n",
8+
"In this notebook, we look at how attention is implemented. We will focus on implementing attention in isolation from a larger model. That's because when implementing attention in a real-world model, a lot of the focus goes into piping the data and juggling the various vectors rather than the concepts of attention themselves.\n",
9+
"\n",
10+
"We will implement attention scoring as well as calculating an attention context vector.\n",
11+
"\n",
12+
"## Attention Scoring\n",
13+
"### Inputs to the scoring function\n",
14+
"Let's start by looking at the inputs we'll give to the scoring function. We will assume we're in the first step in the decoding phase. The first input to the scoring function is the hidden state of decoder (assuming a toy RNN with three hidden nodes -- not usable in real life, but easier to illustrate):"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"metadata": {},
21+
"outputs": [],
22+
"source": [
23+
"dec_hidden_state = [5,1,20]"
24+
]
25+
},
26+
{
27+
"cell_type": "markdown",
28+
"metadata": {},
29+
"source": [
30+
"Let's visualize this vector:"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"%matplotlib inline\n",
40+
"import numpy as np\n",
41+
"import matplotlib.pyplot as plt\n",
42+
"import seaborn as sns\n",
43+
"\n",
44+
"# Let's visualize our decoder hidden state\n",
45+
"plt.figure(figsize=(1.5, 4.5))\n",
46+
"sns.heatmap(np.transpose(np.matrix(dec_hidden_state)), annot=True, cmap=sns.light_palette(\"purple\", as_cmap=True), linewidths=1)"
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"metadata": {},
52+
"source": [
53+
"Our first scoring function will score a single annotation (encoder hidden state), which looks like this:"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"annotation = [3,12,45] #e.g. Encoder hidden state"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": null,
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
71+
"# Let's visualize the single annotation\n",
72+
"plt.figure(figsize=(1.5, 4.5))\n",
73+
"sns.heatmap(np.transpose(np.matrix(annotation)), annot=True, cmap=sns.light_palette(\"orange\", as_cmap=True), linewidths=1)"
74+
]
75+
},
76+
{
77+
"cell_type": "markdown",
78+
"metadata": {},
79+
"source": [
80+
"### IMPLEMENT: Scoring a Single Annotation\n",
81+
"Let's calculate the dot product of a single annotation. NumPy's [dot()](https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html) is a good candidate for this operation"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"metadata": {},
88+
"outputs": [],
89+
"source": [
90+
"def single_dot_attention_score(dec_hidden_state, enc_hidden_state):\n",
91+
" # TODO: return the dot product of the two vectors\n",
92+
" return \n",
93+
" \n",
94+
"single_dot_attention_score(dec_hidden_state, annotation)"
95+
]
96+
},
97+
{
98+
"cell_type": "markdown",
99+
"metadata": {},
100+
"source": [
101+
"\n",
102+
"### Annotations Matrix\n",
103+
"Let's now look at scoring all the annotations at once. To do that, here's our annotation matrix:"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"metadata": {},
110+
"outputs": [],
111+
"source": [
112+
"annotations = np.transpose([[3,12,45], [59,2,5], [1,43,5], [4,3,45.3]])"
113+
]
114+
},
115+
{
116+
"cell_type": "markdown",
117+
"metadata": {},
118+
"source": [
119+
"And it can be visualized like this (each column is a hidden state of an encoder time step):"
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": null,
125+
"metadata": {},
126+
"outputs": [],
127+
"source": [
128+
"# Let's visualize our annotation (each column is an annotation)\n",
129+
"ax = sns.heatmap(annotations, annot=True, cmap=sns.light_palette(\"orange\", as_cmap=True), linewidths=1)"
130+
]
131+
},
132+
{
133+
"cell_type": "markdown",
134+
"metadata": {},
135+
"source": [
136+
"### IMPLEMENT: Scoring All Annotations at Once\n",
137+
"Let's calculate the scores of all the annotations in one step using matrix multiplication. Let's continue to us the dot scoring method\n",
138+
"\n",
139+
"<img src=\"images/scoring_functions.png\" />\n",
140+
"\n",
141+
"To do that, we'll have to transpose `dec_hidden_state` and [matrix multiply](https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html) it with `annotations`."
142+
]
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": null,
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"def dot_attention_score(dec_hidden_state, annotations):\n",
151+
" # TODO: return the product of dec_hidden_state transpose and enc_hidden_states\n",
152+
" return \n",
153+
" \n",
154+
"attention_weights_raw = dot_attention_score(dec_hidden_state, annotations)\n",
155+
"attention_weights_raw"
156+
]
157+
},
158+
{
159+
"cell_type": "markdown",
160+
"metadata": {},
161+
"source": [
162+
"Looking at these scores, can you guess which of the four vectors will get the most attention from the decoder at this time step?\n",
163+
"\n",
164+
"## Softmax\n",
165+
"Now that we have our scores, let's apply softmax:\n",
166+
"<img src=\"images/softmax.png\" />"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": null,
172+
"metadata": {},
173+
"outputs": [],
174+
"source": [
175+
"def softmax(x):\n",
176+
" x = np.array(x, dtype=np.float128)\n",
177+
" e_x = np.exp(x)\n",
178+
" return e_x / e_x.sum(axis=0) \n",
179+
"\n",
180+
"attention_weights = softmax(attention_weights_raw)\n",
181+
"attention_weights"
182+
]
183+
},
184+
{
185+
"cell_type": "markdown",
186+
"metadata": {},
187+
"source": [
188+
"Even when knowing which annotation will get the most focus, it's interesting to see how drastic softmax makes the end score become. The first and last annotation had the respective scores of 927 and 929. But after softmax, the attention they'll get is 0.12 and 0.88 respectively.\n",
189+
"\n",
190+
"# Applying the scores back on the annotations\n",
191+
"Now that we have our scores, let's multiply each annotation by its score to proceed closer to the attention context vector. This is the multiplication part of this formula (we'll tackle the summation part in the latter cells)\n",
192+
"\n",
193+
"<img src=\"images/Context_vector.png\" />"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": null,
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"def apply_attention_scores(attention_weights, annotations):\n",
203+
" # TODO: Multiple the annotations by their weights\n",
204+
" return\n",
205+
"\n",
206+
"applied_attention = apply_attention_scores(attention_weights, annotations)\n",
207+
"applied_attention"
208+
]
209+
},
210+
{
211+
"cell_type": "markdown",
212+
"metadata": {},
213+
"source": [
214+
"Let's visualize how the context vector looks now that we've applied the attention scores back on it:"
215+
]
216+
},
217+
{
218+
"cell_type": "code",
219+
"execution_count": null,
220+
"metadata": {},
221+
"outputs": [],
222+
"source": [
223+
"# Let's visualize our annotations after applying attention to them\n",
224+
"ax = sns.heatmap(applied_attention, annot=True, cmap=sns.light_palette(\"orange\", as_cmap=True), linewidths=1)"
225+
]
226+
},
227+
{
228+
"cell_type": "markdown",
229+
"metadata": {},
230+
"source": [
231+
"Contrast this with the raw annotations visualized earlier in the notebook, and we can see that the second and third annotations (columns) have been nearly wiped out. The first annotation maintains some of its value, and the fourth annotation is the most pronounced.\n",
232+
"\n",
233+
"# Calculating the Attention Context Vector\n",
234+
"All that remains to produce our attention context vector now is to sum up the four columns to produce a single attention context vector\n"
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"def calculate_attention_vector(applied_attention):\n",
244+
" return np.sum(applied_attention, axis=1)\n",
245+
"\n",
246+
"attention_vector = calculate_attention_vector(applied_attention)\n",
247+
"attention_vector"
248+
]
249+
},
250+
{
251+
"cell_type": "code",
252+
"execution_count": null,
253+
"metadata": {
254+
"scrolled": false
255+
},
256+
"outputs": [],
257+
"source": [
258+
"# Let's visualize the attention context vector\n",
259+
"plt.figure(figsize=(1.5, 4.5))\n",
260+
"sns.heatmap(np.transpose(np.matrix(attention_vector)), annot=True, cmap=sns.light_palette(\"Blue\", as_cmap=True), linewidths=1)"
261+
]
262+
},
263+
{
264+
"cell_type": "markdown",
265+
"metadata": {},
266+
"source": [
267+
"Now that we have the context vector, we can concatenate it with the hidden state and pass it through a hidden layer to produce the the result of this decoding time step."
268+
]
269+
}
270+
],
271+
"metadata": {
272+
"anaconda-cloud": {},
273+
"kernelspec": {
274+
"display_name": "Python 3",
275+
"language": "python",
276+
"name": "python3"
277+
},
278+
"language_info": {
279+
"codemirror_mode": {
280+
"name": "ipython",
281+
"version": 3
282+
},
283+
"file_extension": ".py",
284+
"mimetype": "text/x-python",
285+
"name": "python",
286+
"nbconvert_exporter": "python",
287+
"pygments_lexer": "ipython3",
288+
"version": "3.6.3"
289+
}
290+
},
291+
"nbformat": 4,
292+
"nbformat_minor": 1
293+
}

0 commit comments

Comments
 (0)