{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Artificial Neural Networks" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pyaudi import gdual_double as gdual\n", "from pyaudi import sin, cos, tanh\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the architecture of the network:\n", "\n", " Inputs: 3\n", " Hidden layers: 2 with 5 units/layer\n", " Outputs: 1\n", "\n", "We will need the first order derivatives" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "n_units = [3, 5, 5, 1]\n", "order = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Definition of parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create symbolic variables for the weights with values drawn from $\\mathcal N(0,1)$" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def initialize_weights(n_units, order):\n", "\n", " weights = []\n", "\n", " for layer in range(1, len(n_units)):\n", " weights.append([])\n", " for unit in range(n_units[layer]):\n", " weights[-1].append([])\n", " for prev_unit in range(n_units[layer-1]): \n", " symname = 'w_{{({0},{1},{2})}}'.format(layer, unit, prev_unit)\n", " w = gdual(np.random.randn(), symname , order)\n", " weights[-1][-1].append(w)\n", " \n", " return weights\n", "\n", "weights = initialize_weights(n_units, order) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And for the biases, initialized to 1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def initialize_biases(n_units, order):\n", "\n", " biases = []\n", "\n", " for layer in range(1, len(n_units)):\n", " biases.append([])\n", " for unit in range(n_units[layer]):\n", " symname = 'b_{{({0},{1})}}'.format(layer, unit)\n", " b = gdual(1, symname , order)\n", " biases[-1].append(b)\n", " \n", " return biases\n", "\n", "biases = initialize_biases(n_units, order) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural network as a gdual expression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a function which output is the expression (*gdual*) corresponding to the neural network" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def N_f(inputs, w, b):\n", " \n", " prev_layer_outputs = inputs\n", " \n", " #Hidden layers\n", " for layer in range(len(weights)):\n", " \n", " this_layer_outputs = []\n", " \n", " for unit in range(len(weights[layer])):\n", " \n", " unit_output = 0\n", " unit_output += b[layer][unit]\n", " \n", " for prev_unit,prev_output in enumerate(prev_layer_outputs):\n", " unit_output += w[layer][unit][prev_unit] * prev_output\n", " \n", " if layer != len(weights)-1:\n", " unit_output = tanh(unit_output)\n", " \n", " this_layer_outputs.append(unit_output)\n", " \n", " prev_layer_outputs = this_layer_outputs\n", "\n", " return prev_layer_outputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The new function can be used to compute the output of the network given any input:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "N(x) = 2.565934265083211\n" ] } ], "source": [ "x1 = 1\n", "x2 = 2\n", "x3 = 4\n", "\n", "N = N_f([x1,x2, x3], 
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural network as a gdual expression" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We create a function whose output is the expression (*gdual*) corresponding to the neural network:" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def N_f(inputs, w, b):\n", "\n", "    prev_layer_outputs = inputs\n", "\n", "    # Hidden layers and output layer\n", "    for layer in range(len(w)):\n", "\n", "        this_layer_outputs = []\n", "\n", "        for unit in range(len(w[layer])):\n", "\n", "            # Weighted sum of the previous layer's outputs plus the bias\n", "            unit_output = 0\n", "            unit_output += b[layer][unit]\n", "\n", "            for prev_unit, prev_output in enumerate(prev_layer_outputs):\n", "                unit_output += w[layer][unit][prev_unit] * prev_output\n", "\n", "            # tanh activation everywhere except the (linear) output layer\n", "            if layer != len(w)-1:\n", "                unit_output = tanh(unit_output)\n", "\n", "            this_layer_outputs.append(unit_output)\n", "\n", "        prev_layer_outputs = this_layer_outputs\n", "\n", "    return prev_layer_outputs" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The new function can be used to compute the output of the network given any input:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "N(x) = 2.565934265083211\n" ] } ], "source": [ "x1 = 1\n", "x2 = 2\n", "x3 = 4\n", "\n", "N = N_f([x1, x2, x3], weights, biases)[0]\n", "print('N(x) = {0}'.format(N.constant_cf))\n", "#N" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Loss function" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In our case the desired output of the network will be $y(\\mathbf{x}) = x_1 x_2 + 0.5 x_3 + 2$:" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 6.0\n" ] } ], "source": [ "def y_f(x):\n", "    return x[0]*x[1] + 0.5*x[2] + 2\n", "\n", "y = y_f([x1, x2, x3])\n", "print('y = {0}'.format(y))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The training process will seek to minimize a loss function corresponding to the quadratic error." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss = 11.792807471729587\n" ] } ], "source": [ "def loss_f(N, y):\n", "    return (N-y)**2\n", "\n", "loss = loss_f(N, y)\n", "print('loss = {0}'.format(loss.constant_cf))\n", "#loss" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The function loss_f accepts a gdual (N) and a float (y), and its output is still a gdual. This makes it possible to compute its derivatives with respect to any of the parameters of the network:" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "w_{(2,4,0)}\n", "w_{(2,4,1)}\n", "w_{(2,4,2)}\n", "w_{(2,4,3)}\n", "w_{(2,4,4)}\n", "w_{(3,0,0)}\n", "w_{(3,0,1)}\n", "w_{(3,0,2)}\n", "w_{(3,0,3)}\n", "w_{(3,0,4)}\n" ] } ], "source": [ "for symbol in loss.symbol_set[-10:]:\n", "    print(symbol)" ] },
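{ "cell_type": "markdown", "metadata": {}, "source": [ "For instance, the derivative of the loss with respect to a single weight can be extracted by requesting the first-order coefficient in that weight's symbol. The short sketch below is illustrative and uses the same symbol_set / get_derivative machinery employed by the update function defined next:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch: derivative of the loss wrt one particular weight.\n", "# We locate the weight's symbol inside loss.symbol_set and request a\n", "# first-order derivative in that variable (all other orders set to 0).\n", "w_target = weights[0][0][0]\n", "if w_target.symbol_set[0] in loss.symbol_set:\n", "    symbol_idx = loss.symbol_set.index(w_target.symbol_set[0])\n", "    d_idx = [0]*loss.symbol_set_size\n", "    d_idx[symbol_idx] = 1\n", "    print('dloss/dw = {0}'.format(loss.get_derivative(d_idx)))" ] },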
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Training (Gradient descent)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We update the weights with gradient descent, using the first-order derivatives of the loss function with respect to the weights (and biases):" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def GD_update(loss, w, b, lr):\n", "\n", "    # Update weights\n", "    for layer in range(len(w)):\n", "        for unit in range(len(w[layer])):\n", "            for prev_unit in range(len(w[layer][unit])):\n", "\n", "                weight = w[layer][unit][prev_unit]\n", "                if weight.symbol_set[0] in loss.symbol_set:\n", "                    symbol_idx = loss.symbol_set.index(weight.symbol_set[0])\n", "                    d_idx = [0]*loss.symbol_set_size\n", "                    d_idx[symbol_idx] = 1\n", "\n", "                    # e.g. if d_idx = [1,0,0,0,...] we get the derivative of loss wrt\n", "                    # the first symbol (variable) in loss.symbol_set\n", "                    w[layer][unit][prev_unit] -= loss.get_derivative(d_idx) * lr\n", "\n", "    # Update biases\n", "    for layer in range(len(b)):\n", "        for unit in range(len(b[layer])):\n", "\n", "            bias = b[layer][unit]\n", "            if bias.symbol_set[0] in loss.symbol_set:\n", "                symbol_idx = loss.symbol_set.index(bias.symbol_set[0])\n", "                d_idx = [0]*loss.symbol_set_size\n", "                d_idx[symbol_idx] = 1\n", "                b[layer][unit] -= loss.get_derivative(d_idx) * lr\n", "\n", "    return w, b" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After one GD step the loss decreases:" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss = 5.648639112617553\n" ] } ], "source": [ "weights, biases = GD_update(loss, weights, biases, 0.01)\n", "\n", "N = N_f([x1, x2, x3], weights, biases)[0]\n", "loss = loss_f(N, y)\n", "print('loss = {0}'.format(loss.constant_cf))\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We now do the same over several randomly generated data points:" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "X = -1 + 2*np.random.rand(10, 3)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 49, training loss: 0.10428705104913479\n" ] }, { "data": { "text/plain": [ "Text(0,0.5,'loss')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "<matplotlib figure: training loss vs epoch>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "loss_history = []\n", "\n", "weights = initialize_weights(n_units, order)\n", "biases = initialize_biases(n_units, order)\n", "\n", "epochs = 50\n", "\n", "for e in range(epochs):\n", "\n", "    # Record the mean loss over the data set at the start of this epoch\n", "    epoch_loss = []\n", "    for xi in X:\n", "        N = N_f(xi, weights, biases)[0]\n", "        loss = loss_f(N, y_f(xi))\n", "        epoch_loss.append(loss.constant_cf)\n", "\n", "    loss_history.append(np.mean(epoch_loss))\n", "\n", "    # One gradient-descent update per data point\n", "    for xi in X:\n", "        N = N_f(xi, weights, biases)[0]\n", "        loss = loss_f(N, y_f(xi))\n", "        weights, biases = GD_update(loss, weights, biases, 0.001)\n", "\n", "print('epoch {0}, training loss: {1}'.format(e, loss_history[-1]))\n", "plt.plot(loss_history)\n", "plt.xlabel('epochs')\n", "plt.ylabel('loss')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 1 }