aboutsummaryrefslogtreecommitdiff
path: root/transformer_shortest_paths.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'transformer_shortest_paths.ipynb')
-rw-r--r--transformer_shortest_paths.ipynb947
1 files changed, 593 insertions, 354 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb
index ee844fb..1c6cdc6 100644
--- a/transformer_shortest_paths.ipynb
+++ b/transformer_shortest_paths.ipynb
@@ -11,7 +11,8 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 20,
+ "execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -19,53 +20,46 @@
"id": "ge5QvElvhCOw",
"outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "imports complete\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# imports\n",
- "import numpy as np\n",
"from collections import deque\n",
"import pickle\n",
"from tqdm import tqdm\n",
- "np.random.seed(42)\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import pickle\n",
"from math import sqrt\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
+ "%matplotlib widget\n",
"import matplotlib.pyplot as plt\n",
"torch.manual_seed(42)\n",
"\n",
"import os\n",
+ "from IPython.display import clear_output\n",
+ "import ipdb\n",
"\n",
- "print(\"imports complete\")"
+ "import random\n",
+ "random.seed(42)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 13,
+ "execution_state": "idle",
"metadata": {
"id": "lylOX2POPwFL"
},
"outputs": [],
"source": [
- "SEQ_LEN = 32\n",
- "\n",
+ "SEQ_LEN = 65 # means 32 edges, final token is the target vertex\n",
"PAD_TOKEN = 0\n",
"AVG_DEG = 2\n",
- "MAX_VTXS = SEQ_LEN//AVG_DEG - 1\n",
+ "MAX_VTXS = SEQ_LEN//AVG_DEG - 1 # 31\n",
+ "MIN_VTXS = 8\n",
+ "MAX_TUNE_VTXS = 16\n",
"# vertex labels start at 1 and are always below MAX_VTXS\n",
- "# we also have a padding token which is 0.\n",
- "\n",
- "INF = MAX_VTXS # represents unreachability"
+ "# we also have a padding token which is 0."
]
},
{
@@ -79,7 +73,8 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
+ "execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -92,30 +87,26 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|██████████| 1/1 [00:14<00:00, 14.42s/it]\n"
+ "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:06<00:00, 3.52it/s]\n"
]
}
],
"source": [
"# original task data\n",
- "NTRAIN1 = 100_000\n",
+ "NTRAIN1 = 300_000\n",
"# the data will be edge lists\n",
- "# like this: [1 3 1 5 2 4 0 0 0 0]\n",
+ "# like this: [1 3 1 5 2 4 0 0 0 0 2]\n",
"# this represents edges (1,3), (1,5) (2,4)\n",
"# (the zeros are just padding tokens)\n",
+ "# the final token (2 here) is the target vertex\n",
"\n",
"# the label is the shortest distance from vtx 1 to vtx 2\n",
- "# or \"INF\" if no path exists\n",
- "\n",
- "# fine tuning data\n",
- "NTRAIN2 = 2000\n",
- "# I haven't totally figured out how to do the fine tuning yet.\n",
- "# So don't worry about this yet.\n",
+ "# or \"number of vertices\" if no path exists\n",
"\n",
"def random_graph(n):\n",
" edge_list = []\n",
" adjacencies = [set() for _ in range(n+1)]\n",
- " indices = np.random.randint(n, size=(AVG_DEG*(n-1)))+1\n",
+ " indices = [random.randint(1, n-1) for _ in range(AVG_DEG * (n-1))]\n",
" for i in range(0, len(indices), 2):\n",
" u = indices[i]\n",
" v = indices[i + 1]\n",
@@ -124,28 +115,23 @@
" adjacencies[u].add(v)\n",
" adjacencies[v].add(u)\n",
"\n",
- " if np.random.random() < 0.25:\n",
- " edge_list += [1,2]\n",
- " adjacencies[1].add(2)\n",
- " adjacencies[2].add(1)\n",
- "\n",
" edge_list += [PAD_TOKEN]*(SEQ_LEN-len(edge_list))\n",
" return edge_list, adjacencies\n",
"\n",
"\"\"\"\n",
"input: G, represented as an adjacency list\n",
- "output: [INF]+[d(1,i) for i in range(n)] if target=None\n",
"output: [n]+[d(1,i) for i in range(1, n+1)] if target=None, where n is the number of vertices (n also marks unreachable vertices)\n",
"if target is set to some value, then we instead just output that specific distance\n",
"\"\"\"\n",
- "def SSSP(G, target=None):\n",
- " dist = [INF for _ in G]\n",
+ "def SSSP(n, G, target=2):\n",
+ " dist = [n for _ in G]\n",
" dist[1] = 0\n",
" frontier = deque()\n",
" frontier.append(1)\n",
" while len(frontier) > 0:\n",
" vtx = frontier.popleft()\n",
" for x in G[vtx]:\n",
- " if dist[x] == INF:\n",
+ " if dist[x] == n:\n",
" dist[x] = 1 + dist[vtx]\n",
" frontier.append(x)\n",
" if x == target:\n",
@@ -155,159 +141,190 @@
" else:\n",
" return dist\n",
"\n",
- "def fake_SSSP(G, target=None):\n",
- " return 2 in G[1]\n",
- "\n",
"graphs1 = []\n",
"distance1 = []\n",
"\n",
- "graphs2 = []\n",
- "distances2 = []\n",
- "\n",
- "for n in tqdm(range(MAX_VTXS-1, MAX_VTXS)):\n",
- " # for _ in range(NTRAIN1//MAX_VTXS):\n",
- " for _ in range(NTRAIN1):\n",
+ "for n in tqdm(range(MIN_VTXS, MAX_VTXS)):\n",
+ " for _ in range(NTRAIN1//(MAX_VTXS - MIN_VTXS)):\n",
" edge_list, adj_list = random_graph(n)\n",
- " dist = SSSP(adj_list, target=2)\n",
- "\n",
+ " dist = SSSP(n, adj_list)\n",
+ " edge_list[-1] = 2 # target token\n",
" graphs1.append(edge_list)\n",
" distance1.append(dist)\n",
"\n",
- "# for n in range(8, MAX_VTXS//4):\n",
- "# for _ in range(NTRAIN2//MAX_VTXS):\n",
- "# edge_list, adj_list = random_graph(n)\n",
- "# distances = SSSP(adj_list)\n",
- "# graphs2.append(edge_list)\n",
- "# distances2.append(distances)\n",
- "\n",
- "split1 = int(len(graphs1)*3/4)\n",
- "split2 = int(len(graphs2)*3/4)\n",
- "\n",
- "all1 = list(zip(graphs1, distance1))\n",
- "np.random.shuffle(all1)\n",
- "graphs1, distance1 = zip(*all1)\n",
- "\n",
"data = {\n",
- " \"train1-data\": graphs1[:split1],\n",
- " \"train1-labels\": distance1[:split1],\n",
- " \"test1-data\": graphs1[split1:],\n",
- " \"test1-labels\": distance1[split1:]\n",
- " # \"train2-data\": graphs2[:split2],\n",
- " # \"train2-labels\": distances2[:split2],\n",
- " # \"test2-data\": graphs2[split2:],\n",
- " # \"test2-labels\": distances2[split2:]\n",
+ " \"data\": torch.tensor(graphs1),\n",
+ " \"labels\": torch.tensor(distance1, dtype=torch.float32)\n",
"}\n",
"\n",
"with open('data.pkl', 'wb') as file:\n",
- " pickle.dump(data, file)\n",
- "\n"
+ " pickle.dump(data, file)"
]
},
{
"cell_type": "code",
- "execution_count": 74,
+ "execution_count": 16,
+ "execution_state": "idle",
"metadata": {},
+ "outputs": [],
+ "source": [
+ "def vertices_on_shortest_12_path(n, G, target=2):\n",
+ " dist = [n for _ in G]\n",
+ " parent = [-1 for _ in G]\n",
+ " dist[1] = 0\n",
+ " frontier = deque()\n",
+ " frontier.append(1)\n",
+ " while len(frontier) > 0:\n",
+ " vtx = frontier.popleft()\n",
+ " for x in G[vtx]:\n",
+ " if dist[x] == n:\n",
+ " parent[x] = vtx\n",
+ " dist[x] = 1 + dist[vtx]\n",
+ " frontier.append(x)\n",
+ " if x == target:\n",
+ " path = [x]\n",
+ " while parent[x] != -1:\n",
+ " x = parent[x]\n",
+ " path.append(x)\n",
+ " return list(reversed(path))\n",
+ " return []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "execution_state": "idle",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fine tuning data\n",
+ "NTRAIN2 = 2000\n",
+ "\n",
+ "graphs2 = []\n",
+ "distance2 = []\n",
+ "\n",
+ "for n in range(MIN_VTXS, MAX_TUNE_VTXS):\n",
+ " for _ in range(NTRAIN2//(MAX_TUNE_VTXS - MIN_VTXS)):\n",
+ " while True:\n",
+ " edge_list, adj_list = random_graph(n)\n",
+ " path = vertices_on_shortest_12_path(n, adj_list)\n",
+ " if len(path) > 1:\n",
+ " target_vtx_idx = random.randrange(1, len(path))\n",
+ " target_vtx = path[target_vtx_idx]\n",
+ " edge_list[-1] = target_vtx\n",
+ " graphs2.append(edge_list)\n",
+ " distance2.append(target_vtx_idx)\n",
+ " break\n",
+ "\n",
+ "tune_data = {\n",
+ " \"data\": torch.tensor(graphs2),\n",
+ " \"labels\": torch.tensor(distance2, dtype=torch.float32)\n",
+ "}\n",
+ "\n",
+ "with open('tune_data.pkl', 'wb') as file:\n",
+ " pickle.dump(tune_data, file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "execution_state": "idle",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "EpDBxcgaIPpJ",
+ "outputId": "37cf9577-8cd8-444c-ec1a-c6f4b6061b7f"
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "torch.Size([75000, 32])\n",
- "DONE\n"
+ "pre-train dataset size = 149MB\n",
+ "fine-tune dataset = 1MB\n"
]
}
],
"source": [
- "NTRAIN1 = 100000\n",
- "\n",
- "graphs1 = torch.randint(1, MAX_VTXS, (NTRAIN1, SEQ_LEN))\n",
- "\n",
- "# check if token 1 is in the graph\n",
- "def silly_distance(graph):\n",
- " return int(1 in graph)\n",
- "\n",
- "# check if both token 1 and token 2 are in the graph\n",
- "def silly_distance2(graph):\n",
- " return int(1 in graph and 2 in graph and 3 in graph and 4 in graph and 5 in graph)\n",
- "\n",
- "def silly_distance3(graph):\n",
- " for i in range(len(graph)//2):\n",
- " if graph[2*i] + graph[2*i+1] == 3:\n",
- " return 1\n",
- " return 0\n",
- "\n",
- "distance1 = [silly_distance3(graph) for graph in graphs1]\n",
- "\n",
- "split1 = int(len(graphs1)*3/4)\n",
- "\n",
- "data = {\n",
- " \"train1-data\": graphs1[:split1],\n",
- " \"train1-labels\": distance1[:split1],\n",
- " \"test1-data\": graphs1[split1:],\n",
- " \"test1-labels\": distance1[split1:]\n",
- "}\n",
- "\n",
- "print(data[\"train1-data\"].shape)\n",
- "\n",
- "with open('data.pkl', 'wb') as file:\n",
- " pickle.dump(data, file)\n",
- "\n",
- "print(\"DONE\")\n"
+ "print(f\"pre-train dataset size = {os.path.getsize('data.pkl')//(1024*1024)}MB\")\n",
+ "print(f\"fine-tune dataset = {os.path.getsize('tune_data.pkl')//(1024*1024)}MB\")"
]
},
{
"cell_type": "code",
- "execution_count": 76,
- "metadata": {
- "scrolled": true
- },
+ "execution_count": 38,
+ "execution_state": "idle",
+ "metadata": {},
"outputs": [
{
"data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f5d5ab87fe4145eb8728e6d950e749d8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl3klEQVR4nO3df4xV9YH//xegM/iDGUSFkQJKa1dlFYwgMNutn1pZR0ubumKCrWmpUhvdgRRmq0jXgjZNaGh2q11UdmO29I9Sf2xWu8oKJVgxu079gUsKbiGtMYEGB+gPZpTKD5n7/WPL/TqVpfUHcwfej0dyE+ec9z33fd89zTxz5p5Lv0qlUgkAAMXoX+sJAADQuwQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhjqv1BI5m3d3d2bZtWwYNGpR+/frVejoAwJ+gUqnktddey/Dhw9O/f5nXwgTge7Bt27aMHDmy1tMAAN6FrVu3ZsSIEbWeRk0IwPdg0KBBye9PoIaGhlpPBwD4E3R1dWXkyJHV3+MlEoDvwcE/+zY0NAhAADjKlPzxrTL/8A0AUDABCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQmONqPQGOgDsaD7OvszdnAgD0Qa4AAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAUpqYB+M1vfjP9+vXLnDlzqtv27NmT1tbWnHrqqTn55JMzbdq0bN++vcfztmzZkqlTp+bEE0/M0KFDc8stt+TNN9/sMeapp57KRRddlPr6+px99tlZtmzZ217/nnvuyVlnnZWBAwdm0qRJee65547guwUA6BtqFoDPP/98/umf/iljx47tsX3u3Ll57LHH8vDDD2ft2rXZtm1brr766ur+AwcOZOrUqdm3b1+eeeaZfO9738uyZcuyYMGC6phXXnklU6dOzaWXXpr169dnzpw5+eIXv5hVq1ZVxzz44INpa2vLwoUL8+KLL2bcuHFpaWnJjh07emkFAABqo1+lUqn09ou+/vrrueiii3LvvffmG9/4Ri688MLcdddd6ezszOmnn57ly5fnmmuuSZJs2rQp5513Xtrb2zN58uQ88cQT+eQnP5lt27Zl2LBhSZKlS5dm3rx52blzZ+rq6jJv3rysWLEiGzdurL7mtddem127dmXlypVJkkmTJuXiiy/OkiVLkiTd3d0ZOXJkZs+endtuu+1Peh9dXV1pbGxMZ2dnGhoajsBKvUt3NB5mX2dvzgQA+pw++/u7F9XkCmBra2umTp2aKVOm9Ni+bt267N+/v8f2c889N6NGjUp7e3uSpL29PRdccEE1/pKkpaUlXV1deemll6pj/vDYLS0t1WPs27cv69at6zGmf//+mTJlSnXMoezduzddXV09HgAAR5vjevsFH3jggbz44ot5/vnn
37avo6MjdXV1GTx4cI/tw4YNS0dHR3XMW+Pv4P6D+w43pqurK2+88UZ++9vf5sCBA4ccs2nTpv9z7osWLcqdd975jt8zAEBf0qtXALdu3Zovf/nL+f73v5+BAwf25ku/L+bPn5/Ozs7qY+vWrbWeEgDAO9arVwDXrVuXHTt25KKLLqpuO3DgQJ5++uksWbIkq1atyr59+7Jr164eVwG3b9+epqamJElTU9Pb7tY9eJfwW8f84Z3D27dvT0NDQ0444YQMGDAgAwYMOOSYg8c4lPr6+tTX17+nNejTfHYQAIrQq1cAL7vssmzYsCHr16+vPiZMmJDrrruu+t/HH3981qxZU33O5s2bs2XLljQ3NydJmpubs2HDhh53665evToNDQ0ZM2ZMdcxbj3FwzMFj1NXVZfz48T3GdHd3Z82aNdUxAADHql69Ajho0KCcf/75PbaddNJJOfXUU6vbZ86cmba2tgwZMiQNDQ2ZPXt2mpubM3ny5CTJ5ZdfnjFjxuRzn/tcFi9enI6Ojtx+++1pbW2tXp276aabsmTJktx666254YYb8uSTT+ahhx7KihUrqq/b1taWGTNmZMKECZk4cWLuuuuu7N69O9dff31vLgkAQK/r9ZtA/phvf/vb6d+/f6ZNm5a9e/empaUl9957b3X/gAED8vjjj+fmm29Oc3NzTjrppMyYMSNf//rXq2NGjx6dFStWZO7cubn77rszYsSI3H///WlpaamOmT59enbu3JkFCxako6MjF154YVauXPm2G0MAAI41NfkewGNFn/0eoXf7WT6fAQSgAH3293cv8m8BAwAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFOa4Wk+Aw7ij8TD7OntzJgDAMcQVQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgML0egDed999GTt2bBoaGtLQ0JDm5uY88cQT1f179uxJa2trTj311Jx88smZNm1atm/f3uMYW7ZsydSpU3PiiSdm6NChueWWW/Lmm2/2GPPUU0/loosuSn19fc4+++wsW7bsbXO55557ctZZZ2XgwIGZNGlSnnvuuSP4zgEA+oZeD8ARI0bkm9/8ZtatW5cXXnghH//4x/PpT386L730UpJk7ty5eeyxx/Lwww9n7dq12bZtW66++urq8w8cOJCpU6dm3759eeaZZ/K9730vy5Yty4IFC6pjXnnllUydOjWXXnpp1q9fnzlz5uSLX/xiVq1aVR3z4IMPpq2tLQsXLsyLL76YcePGpaWlJTt27OjlFQEA6F39KpVKpdaTGDJkSL71rW/lmmuuyemnn57ly5fnmmuuSZJs2rQp5513Xtrb2zN58uQ88cQT+eQnP5lt27Zl2LBhSZKlS5dm3rx52blzZ+rq6jJv3rysWLEiGzdurL7Gtddem127dmXlypVJkkmTJuXiiy/OkiVLkiTd3d0ZOXJkZs+endtuu+1PmndXV1caGxvT2dmZhoaG939h7mg8zL7OvvM8ADiKHPHf30eBmn4G8MCBA3nggQeye/fuNDc3Z926ddm/f3+mTJlSHXPuuedm1KhRaW9vT5K0t7fnggsuqMZfkrS0tKSrq6t6FbG9vb3HMQ6OOXiMffv2Zd26dT3G9O/fP1OmTKmO
OZS9e/emq6urxwMA4GhTkwDcsGFDTj755NTX1+emm27KI488kjFjxqSjoyN1dXUZPHhwj/HDhg1LR0dHkqSjo6NH/B3cf3Df4cZ0dXXljTfeyK9+9ascOHDgkGMOHuNQFi1alMbGxupj5MiR73ElAAB6X00C8Jxzzsn69evz7LPP5uabb86MGTPyP//zP7WYyjsyf/78dHZ2Vh9bt26t9ZQAAN6x42rxonV1dTn77LOTJOPHj8/zzz+fu+++O9OnT8++ffuya9euHlcBt2/fnqampiRJU1PT2+7WPXiX8FvH/OGdw9u3b09DQ0NOOOGEDBgwIAMGDDjkmIPHOJT6+vrU19e/5/cPAFBLfeJ7ALu7u7N3796MHz8+xx9/fNasWVPdt3nz5mzZsiXNzc1Jkubm5mzYsKHH3bqrV69OQ0NDxowZUx3z1mMcHHPwGHV1dRk/fnyPMd3d3VmzZk11DADAsarXrwDOnz8/V155ZUaNGpXXXnsty5cvz1NPPZVVq1alsbExM2fOTFtbW4YMGZKGhobMnj07zc3NmTx5cpLk8ssvz5gxY/K5z30uixcvTkdHR26//fa0trZWr87ddNNNWbJkSW699dbccMMNefLJJ/PQQw9lxYoV1Xm0tbVlxowZmTBhQiZOnJi77roru3fvzvXXX9/bSwIA0Kt6PQB37NiRz3/+83n11VfT2NiYsWPHZtWqVfmrv/qrJMm3v/3t9O/fP9OmTcvevXvT0tKSe++9t/r8AQMG5PHHH8/NN9+c5ubmnHTSSZkxY0a+/vWvV8eMHj06K1asyNy5c3P33XdnxIgRuf/++9PS0lIdM3369OzcuTMLFixIR0dHLrzwwqxcufJtN4YAABxr+sT3AB6tfA8gABx9fA9gH/kMIAAAvUcAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAU5rhaT4BjgH9DGACOKq4AAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAUptcDcNGiRbn44oszaNCgDB06NFdddVU2b97cY8yePXvS2tqaU089NSeffHKmTZuW7du39xizZcuWTJ06NSeeeGKGDh2aW265JW+++WaPMU899VQuuuii1NfX5+yzz86yZcveNp977rknZ511VgYOHJhJkyblueeeO0LvHACgb+j1AFy7dm1aW1vzk5/8JKtXr87+/ftz+eWXZ/fu3dUxc+fOzWOPPZaHH344a9euzbZt23L11VdX9x84cCBTp07Nvn378swzz+R73/teli1blgULFlTHvPLKK5k6dWouvfTSrF+/PnPmzMkXv/jFrFq1qjrmwQcfTFtbWxYuXJgXX3wx48aNS0tLS3bs2NGLKwIA0Lv6VSqVSi0nsHPnzgwdOjRr167NJZdcks7Ozpx++ulZvnx5rrnmmiTJpk2bct5556W9vT2TJ0/OE088kU9+8pPZtm1bhg0bliRZunRp5s2bl507d6auri7z5s3LihUrsnHjxuprXXvttdm1a1dWrlyZJJk0aVIuvvjiLFmyJEnS3d2dkSNHZvbs2bntttv+6Ny7urrS2NiYzs7ONDQ0vP+Lc0fjYfZ1Hv3PA4AaOOK/v48CNf8MYGfn/wbCkCFDkiTr1q3L/v37M2XKlOqYc889N6NG
jUp7e3uSpL29PRdccEE1/pKkpaUlXV1deemll6pj3nqMg2MOHmPfvn1Zt25djzH9+/fPlClTqmMAAI5Fx9Xyxbu7uzNnzpx85CMfyfnnn58k6ejoSF1dXQYPHtxj7LBhw9LR0VEd89b4O7j/4L7Djenq6sobb7yR3/72tzlw4MAhx2zatOmQ8927d2/27t1b/bmrq+s9vHsAgNqo6RXA1tbWbNy4MQ888EAtp/EnW7RoURobG6uPkSNH1npKAADvWM0CcNasWXn88cfz4x//OCNGjKhub2pqyr59+7Jr164e47dv356mpqbqmD+8K/jgz39sTENDQ0444YScdtppGTBgwCHHHDzGH5o/f346Ozurj61bt76nNQAAqIVeD8BKpZJZs2blkUceyZNPPpnRo0f32D9+/Pgcf/zxWbNmTXXb5s2bs2XLljQ3NydJmpubs2HDhh53665evToNDQ0ZM2ZMdcxbj3FwzMFj1NXVZfz48T3GdHd3Z82aNdUxf6i+vj4NDQ09HgAAR5te/wxga2trli9fnh/+8IcZNGhQ9TN7jY2NOeGEE9LY2JiZM2emra0tQ4YMSUNDQ2bPnp3m5uZMnjw5SXL55ZdnzJgx+dznPpfFixeno6Mjt99+e1pbW1NfX58kuemmm7JkyZLceuutueGGG/Lkk0/moYceyooVK6pzaWtry4wZMzJhwoRMnDgxd911V3bv3p3rr7++t5cFAKDX9HoA3nfffUmSj33sYz22f/e7380XvvCFJMm3v/3t9O/fP9OmTcvevXvT0tKSe++9tzp2wIABefzxx3PzzTenubk5J510UmbMmJGvf/3r1TGjR4/OihUrMnfu3Nx9990ZMWJE7r///rS0tFTHTJ8+PTt37syCBQvS0dGRCy+8MCtXrnzbjSEAAMeSmn8P4NHM9wC+x+cBQA34HsA+8D2AAAD0LgEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUJjjaj0BCnZH42H2dfbmTACgKK4AAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFKbXA/Dpp5/Opz71qQwfPjz9+vXLo48+2mN/pVLJggULcsYZZ+SEE07IlClT8vOf/7zHmN/85je57rrr0tDQkMGDB2fmzJl5/fXXe4z56U9/mo9+9KMZOHBgRo4cmcWLF79tLg8//HDOPffcDBw4MBdccEH+4z/+4wi9awCAvqPXA3D37t0ZN25c7rnnnkPuX7x4cb7zne9k6dKlefbZZ3PSSSelpaUle/bsqY657rrr8tJLL2X16tV5/PHH8/TTT+dLX/pSdX9XV1cuv/zynHnmmVm3bl2+9a1v5Y477sg///M/V8c888wz+cxnPpOZM2fmv//7v3PVVVflqquuysaNG4/wCgAA1Fa/SqVSqdmL9+uXRx55JFdddVXy+6t/w4cPz9/+7d/mK1/5SpKks7Mzw4YNy7Jly3LttdfmZz/7WcaMGZPnn38+EyZMSJKsXLkyn/jEJ/LLX/4yw4cPz3333Ze/+7u/S0dHR+rq6pIkt912Wx599NFs2rQpSTJ9+vTs3r07jz/+eHU+kydPzoUXXpilS5f+SfPv6upKY2NjOjs709DQ8L6vT+5oPMy+znKfBwDvwRH//X0U6FOf
AXzllVfS0dGRKVOmVLc1NjZm0qRJaW9vT5K0t7dn8ODB1fhLkilTpqR///559tlnq2MuueSSavwlSUtLSzZv3pzf/va31TFvfZ2DYw6+DgDAseq4Wk/grTo6OpIkw4YN67F92LBh1X0dHR0ZOnRoj/3HHXdchgwZ0mPM6NGj33aMg/tOOeWUdHR0HPZ1DmXv3r3Zu3dv9eeurq53+U4BAGqnT10B7OsWLVqUxsbG6mPkyJG1nhIAwDvWpwKwqakpSbJ9+/Ye27dv317d19TUlB07dvTY/+abb+Y3v/lNjzGHOsZbX+P/GnNw/6HMnz8/nZ2d1cfWrVvfw7sFAKiNPhWAo0ePTlNTU9asWVPd1tXVlWeffTbNzc1Jkubm5uzatSvr1q2rjnnyySfT3d2dSZMmVcc8/fTT2b9/f3XM6tWrc8455+SUU06pjnnr6xwcc/B1DqW+vj4NDQ09HgAAR5teD8DXX38969evz/r165Pf3/ixfv36bNmyJf369cucOXPyjW98I//+7/+eDRs25POf/3yGDx9evVP4vPPOyxVXXJEbb7wxzz33XP7rv/4rs2bNyrXXXpvhw4cnST772c+mrq4uM2fOzEsvvZQHH3wwd999d9ra2qrz+PKXv5yVK1fm7//+77Np06bccccdeeGFFzJr1qzeXhIAgF7V6zeBvPDCC7n00kurPx+MshkzZmTZsmW59dZbs3v37nzpS1/Krl278pd/+ZdZuXJlBg4cWH3O97///cyaNSuXXXZZ+vfvn2nTpuU73/lOdX9jY2N+9KMfpbW1NePHj89pp52WBQsW9PiuwL/4i7/I8uXLc/vtt+erX/1qPvzhD+fRRx/N+eef32trAQBQCzX9HsCjne8BrNHzAOA98D2AfewzgAAAHHkCEACgMAIQAKAwAhAAoDACEACgMH3q3wKG98ydxQDwR7kCCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQmONqPQE4qt3ReJh9nb05EwD4k7kCCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBhfAwO14OtjAKghVwABAAojAAEACuNPwHA08adjAN4HrgACABTGFUDg3V9Z7O0rksfCFdCjZa1Ldyys97HwHjhiBCDQ+46FcDxafrn2teDs7fkcLf87Hc6x8P+XvvR6JAIQAN5HJcdMye/9KOQzgAAAhRGAAACFEYAAAIURgAAAhRGAAACFcRcwxN1rAJTFFUAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIUH4D33HNPzjrrrAwcODCTJk3Kc889V+spAQAcUUUH4IMPPpi2trYsXLgwL774YsaNG5eWlpbs2LGj1lMDADhiig7Af/iHf8iNN96Y66+/PmPGjMnSpUtz4okn5l/+5V9qPTUAgCPmuFpPoFb27duXdevWZf78+dVt/fv3z5QpU9Le3n7I5+zduzd79+6t/tzZ2Zkk6erqOjKT3Fv5v/cd7jWP9ecdztHyHvrSmtViPkfL8w7naHkPfWnNajGfo+V5h3O0vIe+tGZ/xMHf25XKYV77GNevUui737ZtWz7wgQ/kmWeeSXNzc3X7rbfemrVr1+bZZ59923PuuOOO3Hnnnb08UwDgSNi6dWtGjBhR62nURLFXAN+N+fPnp62trfpzd3d3fvOb3+TUU09Nv379eozt6urKyJEjs3Xr1jQ0NNRgtkcn6/bOWbN3x7q9O9bt3bFu79yRXLNKpZLXXnstw4cPf1+PezQpNgBPO+20DBgwINu3b++xffv27Wlqajrkc+rr61NfX99j
2+DBgw/7Og0NDf7P/i5Yt3fOmr071u3dsW7vjnV7547UmjU2Nr7vxzyaFHsTSF1dXcaPH581a9ZUt3V3d2fNmjU9/iQMAHCsKfYKYJK0tbVlxowZmTBhQiZOnJi77roru3fvzvXXX1/rqQEAHDFFB+D06dOzc+fOLFiwIB0dHbnwwguzcuXKDBs27D0fu76+PgsXLnzbn4w5POv2zlmzd8e6vTvW7d2xbu+cNTuyir0LGACgVMV+BhAAoFQCEACgMAIQAKAwAhAAoDAC8Ai45557ctZZZ2XgwIGZNGlSnnvuuVpPqU+744470q9fvx6Pc889t9bT6nOefvrpfOpTn8rw4cPTr1+/PProoz32VyqVLFiwIGeccUZOOOGETJkyJT//+c9rNt++4o+t2xe+8IW3nX9XXHFFzebbFyxatCgXX3xxBg0alKFDh+aqq67K5s2be4zZs2dPWltbc+qpp+bkk0/OtGnT3vbF+qX5U9btYx/72NvOt5tuuqlmc+4L7rvvvowdO7b6hc/Nzc154oknqvuda0eGAHyfPfjgg2lra8vChQvz4osvZty4cWlpacmOHTtqPbU+7c///M/z6quvVh//+Z//Wesp9Tm7d+/OuHHjcs899xxy/+LFi/Od73wnS5cuzbPPPpuTTjopLS0t2bNnT6/PtS/5Y+uWJFdccUWP8+8HP/hBr86xr1m7dm1aW1vzk5/8JKtXr87+/ftz+eWXZ/fu3dUxc+fOzWOPPZaHH344a9euzbZt23L11VfXdN619qesW5LceOONPc63xYsX12zOfcGIESPyzW9+M+vWrcsLL7yQj3/84/n0pz+dl156KXGuHTkV3lcTJ06stLa2Vn8+cOBAZfjw4ZVFixbVdF592cKFCyvjxo2r9TSOKkkqjzzySPXn7u7uSlNTU+Vb3/pWdduuXbsq9fX1lR/84Ac1mmXf84frVqlUKjNmzKh8+tOfrtmcjgY7duyoJKmsXbu2Uvn9uXX88cdXHn744eqYn/3sZ5Uklfb29hrOtG/5w3WrVCqV//f//l/ly1/+ck3ndTQ45ZRTKvfff79z7QhyBfB9tG/fvqxbty5Tpkypbuvfv3+mTJmS9vb2ms6tr/v5z3+e4cOH54Mf/GCuu+66bNmypdZTOqq88sor6ejo6HHuNTY2ZtKkSc69P8FTTz2VoUOH5pxzzsnNN9+cX//617WeUp/S2dmZJBkyZEiSZN26ddm/f3+P8+3cc8/NqFGjnG9v8YfrdtD3v//9nHbaaTn//PMzf/78/O53v6vRDPueAwcO5IEHHsju3bvT3NzsXDuCiv6XQN5vv/rVr3LgwIG3/Usiw4YNy6ZNm2o2r75u0qRJWbZsWc4555y8+uqrufPOO/PRj340GzduzKBBg2o9vaNCR0dH8vtz7a2GDRtW3cehXXHFFbn66qszevTovPzyy/nqV7+aK6+8Mu3t7RkwYECtp1dz3d3dmTNnTj7ykY/k/PPPT35/vtXV1WXw4ME9xjrf/n+HWrck+exnP5szzzwzw4cPz09/+tPMmzcvmzdvzr/927/VdL61tmHDhjQ3N2fPnj05+eST88gjj2TMmDFZv369c+0IEYDU3JVXXln977Fjx2bSpEk588wz89BDD2XmzJk1nRvHvmuvvbb63xdccEHGjh2bD33oQ3nqqady2WWX1XRufUFra2s2btzoc7nv0P+1bl/60peq/33BBRfkjDPOyGWXXZaXX345H/rQh2ow077hnHPOyfr169PZ2Zl//dd/zYwZM7J27dpaT+uY5k/A76PTTjstAwYMeNvdSdu3b09TU1PN5nW0GTx4cP7sz/4sv/jFL2o9laPGwfPLuffeffCDH8xpp53m/Esya9asPP744/nxj3+cESNGVLc3NTVl37592bVrV4/xzrf/9X+t26FMmjQpSYo/3+rq6nL22Wdn/PjxWbRoUcaNG5e7777buXYECcD3UV1dXcaPH581a9ZUt3V3d2fNmjVp
bm6u6dyOJq+//npefvnlnHHGGbWeylFj9OjRaWpq6nHudXV15dlnn3XuvUO//OUv8+tf/7ro869SqWTWrFl55JFH8uSTT2b06NE99o8fPz7HH398j/Nt8+bN2bJlS9Hn2x9bt0NZv359khR9vh1Kd3d39u7d61w7gvwJ+H3W1taWGTNmZMKECZk4cWLuuuuu7N69O9dff32tp9ZnfeUrX8mnPvWpnHnmmdm2bVsWLlyYAQMG5DOf+Uytp9anvP766z2uErzyyitZv359hgwZklGjRmXOnDn5xje+kQ9/+MMZPXp0vva1r2X48OG56qqrajrvWjvcug0ZMiR33nlnpk2blqamprz88su59dZbc/bZZ6elpaWm866l1tbWLF++PD/84Q8zaNCg6metGhsbc8IJJ6SxsTEzZ85MW1tbhgwZkoaGhsyePTvNzc2ZPHlyradfM39s3V5++eUsX748n/jEJ3Lqqafmpz/9aebOnZtLLrkkY8eOrfX0a2b+/Pm58sorM2rUqLz22mtZvnx5nnrqqaxatcq5diTV+jbkY9E//uM/VkaNGlWpq6urTJw4sfKTn/yk1lPq06ZPn14544wzKnV1dZUPfOADlenTp1d+8Ytf1Hpafc6Pf/zjSpK3PWbMmFGp/P6rYL72ta9Vhg0bVqmvr69cdtlllc2bN9d62jV3uHX73e9+V7n88ssrp59+euX444+vnHnmmZUbb7yx0tHRUetp19Sh1itJ5bvf/W51zBtvvFH5m7/5m8opp5xSOfHEEyt//dd/XXn11VdrOu9a+2PrtmXLlsoll1xSGTJkSKW+vr5y9tlnV2655ZZKZ2dnradeUzfccEPlzDPPrNTV1VVOP/30ymWXXVb50Y9+VN3vXDsy+lX+96QFAKAQPgMIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFCY/w8tmfmOTDaPEwAAAABJRU5ErkJggg==",
+ "text/html": [
+ "\n",
+ " <div style=\"display: inline-block;\">\n",
+ " <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
+ " Figure 2\n",
+ " </div>\n",
+ " <img src='' width=640.0/>\n",
+ " </div>\n",
+ " "
+ ],
"text/plain": [
- "0.1518"
- ]
- },
- "execution_count": 76,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "text/plain": [
- "0.1518"
+ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
]
},
- "execution_count": 75,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "sum(distance1)/len(distance1)"
+ "with plt.ioff():\n",
+ " plt.hist(data['labels'],bins=64)\n",
+ " plt.show()"
]
},
{
"cell_type": "code",
- "execution_count": 77,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "EpDBxcgaIPpJ",
- "outputId": "37cf9577-8cd8-444c-ec1a-c6f4b6061b7f"
- },
+ "execution_count": 50,
+ "execution_state": "idle",
+ "metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "dataset size = 49MB\n"
- ]
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "24b2976d050e43af8bad0e4080a224eb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjcElEQVR4nO3dfZTWdZ3/8dcAcRMxg1DMOCsgtW6IN2Xi0mg3W7KSsZ48cSp2J5eUI3vaoUTSgkrKvAHZXTOKID0ucE66VrurFRVKWLIlIlK03oW2WVDuwO4hZoSOgzLX749fXqcx/W2/1rkuxs/jcc51Ttf3+7nm+/4QytPvXHPRUKlUKgEAoBiD6j0AAAC1JQABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAozpN4DDGS9vb15/PHHM2rUqDQ0NNR7HADg91CpVPLEE0+ktbU1gwaVeS9MAP4vPP744xk/fny9xwAA/gC7d+/OMcccU+8x6kIA/i+MGjUq+c1voMbGxnqPAwD8Hrq7uzN+/Pjqn+MlEoD/C89827exsVEAAsAAU/Lbt8r8xjcAQMEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYYbUewCe37GLvvG85362bGZNZwEAXjzcAQQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAozBEZgJs3b84555yT1tbWNDQ05LbbbutzvlKpZMmSJTn66KMzYsSITJ8+PY8++mifNfv27Ut7e3saGxszevTozJ07NwcOHOiz5t///d/zxje+McOHD8/48eOzfPnymuwPAKCejsgAPHjwYF7zmtdk5cqVz3l++fLlWbFiRVavXp2tW7dm5MiRmTFjRp588snqmvb29jz44IPZuHFj1q9fn82bN2fevHnV893d3TnrrLMyceLEbN++PX/3d3+XT37yk7n++utrskcAgHppqFQqlXoP8f/S0NCQW2+9Neeee27ym7t/ra2t+dCHPpRLLrkkSdLV1ZXm5uasXbs2s2fPzsMPP5wpU6Zk27ZtmTp1apJkw4YNefvb355f/OIXaW1tzapVq/Kxj30snZ2dGTp0aJJk0aJFue222/LjH//495qtu7s7TU1N6erqSmNj4wu+92MXfeN5z/1s2cwX/HoAUIL+/vN7IDgi7wD+vzz22GPp7OzM9OnTq8eampoybdq0bNmyJUmyZcuWjB49uhp/STJ9+vQMGjQoW7dura5505veVI2/JJkxY0Z27tyZX/3qV8957Z6ennR3d/d5AAAMNAMuADs7O5Mkzc3NfY43NzdXz3V2dmbcuHF9zg8ZMiRjxozps+a5vsZvX+PZli5dmqampupj/PjxL+DOAABqY8AFYD0tXrw4XV1d1cfu3bvrPRIAwP+3AReALS0tSZI9e/b0Ob5nz57q
uZaWluzdu7fP+aeffjr79u3rs+a5vsZvX+PZhg0blsbGxj4PAICBZsAF4KRJk9LS0pJNmzZVj3V3d2fr1q1pa2tLkrS1tWX//v3Zvn17dc2dd96Z3t7eTJs2rbpm8+bNeeqpp6prNm7cmFe/+tU56qijaronAIBaOiID8MCBA9mxY0d27NiR/OYHP3bs2JFdu3aloaEhCxYsyJVXXpmvfe1ruf/++/PXf/3XaW1trf6k8PHHH5+3ve1tufDCC3Pvvffm+9//fubPn5/Zs2entbU1SfJXf/VXGTp0aObOnZsHH3wwX/rSl/KZz3wmCxcurOveAQD625B6D/Bc7rvvvrzlLW+pPn8myubMmZO1a9fmwx/+cA4ePJh58+Zl//79ecMb3pANGzZk+PDh1dfcdNNNmT9/fs4888wMGjQos2bNyooVK6rnm5qacscdd6SjoyOnnnpqXv7yl2fJkiV9PisQAODF6Ij/HMAjmc8BBICBx+cAHqHfAgYAoP8IQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCDMgAPHz4cC677LJMmjQpI0aMyKte9apcccUVqVQq1TWVSiVLlizJ0UcfnREjRmT69Ol59NFH+3ydffv2pb29PY2NjRk9enTmzp2bAwcO1GFHAAC1MyAD8JprrsmqVavyuc99Lg8//HCuueaaLF++PJ/97Gera5YvX54VK1Zk9erV2bp1a0aOHJkZM2bkySefrK5pb2/Pgw8+mI0bN2b9+vXZvHlz5s2bV6ddAQDURkPlt2+bDRB/8Rd/kebm5tx4443VY7NmzcqIESPyxS9+MZVKJa2trfnQhz6USy65JEnS1dWV5ubmrF27NrNnz87DDz+cKVOmZNu2bZk6dWqSZMOGDXn729+eX/ziF2ltbf0f5+ju7k5TU1O6urrS2Nj4gu/z2EXfeN5zP1s28wW/HgCUoL///B4IBuQdwNNPPz2bNm3KI488kiT50Y9+lO9973s5++yzkySPPfZYOjs7M3369OprmpqaMm3atGzZsiVJsmXLlowePboaf0kyffr0DBo0KFu3bq35ngAAamVIvQf4QyxatCjd3d2ZPHlyBg8enMOHD+eqq65Ke3t7kqSzszNJ0tzc3Od1zc3N1XOdnZ0ZN25cn/NDhgzJmDFjqmueraenJz09PdXn3d3dL/jeAAD624C8A/jlL385N910U26++eb84Ac/yLp16/L3f//3WbduXb9ed+nSpWlqaqo+xo8f36/XAwDoDwMyAC+99NIsWrQos2fPzkknnZTzzjsvF198cZYuXZokaWlpSZLs2bOnz+v27NlTPdfS0pK9e/f2Of/0009n37591TXPtnjx4nR1dVUfu3fv7qcdAgD0nwEZgL/+9a8zaFDf0QcPHpze3t4kyaRJk9LS0pJNmzZVz3d3d2fr1q1pa2tLkrS1tWX//v3Zvn17dc2dd96Z3t7eTJs27TmvO2zYsDQ2NvZ5AAAMNAPyPYDnnHNOrrrqqkyYMCEnnHBCfvjDH+baa6/NBRdckCRpaGjIggULcuWVV+a4447LpEmTctlll6W1tTXnnntukuT444/P2972tlx44YVZvXp1nnrqqcyfPz+zZ8/+vX4CGABgoBqQAfjZz342l112Wf72b/82e/fuTWtra/7mb/4mS5Ysqa758Ic/nIMHD2bevHnZv39/3vCGN2TDhg0Z
Pnx4dc1NN92U+fPn58wzz8ygQYMya9asrFixok67AgCojQH5OYBHCp8DCAADj88BHKDvAQQA4A8nAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAozYAPwl7/8Zd773vdm7NixGTFiRE466aTcd9991fOVSiVLlizJ0UcfnREjRmT69Ol59NFH+3yNffv2pb29PY2NjRk9enTmzp2bAwcO1GE3AAC1MyAD8Fe/+lXOOOOMvOQlL8m3vvWtPPTQQ/mHf/iHHHXUUdU1y5cvz4oVK7J69eps3bo1I0eOzIwZM/Lkk09W17S3t+fBBx/Mxo0bs379+mzevDnz5s2r064AAGqjoVKpVOo9xP+vRYsW5fvf/37+7d/+7TnPVyqVtLa25kMf+lAuueSSJElXV1eam5uzdu3azJ49Ow8//HCmTJmSbdu2ZerUqUmSDRs25O1vf3t+8YtfpLW19X+co7u7O01NTenq6kpjY+MLvMvk2EXfeN5zP1s28wW/HgCUoL///B4IBuQdwK997WuZOnVq3vWud2XcuHE55ZRTcsMNN1TPP/bYY+ns7Mz06dOrx5qamjJt2rRs2bIlSbJly5aMHj26Gn9JMn369AwaNChbt259zuv29PSku7u7zwMAYKAZkAH405/+NKtWrcpxxx2X22+/Pe9///vzwQ9+MOvWrUuSdHZ2Jkmam5v7vK65ubl6rrOzM+PGjetzfsiQIRkzZkx1zbMtXbo0TU1N1cf48eP7aYcAAP1nQAZgb29vXve61+Xqq6/OKaecknnz5uXCCy/M6tWr+/W6ixcvTldXV/Wxe/fufr0eAEB/GJABePTRR2fKlCl9jh1//PHZtWtXkqSlpSVJsmfPnj5r9uzZUz3X0tKSvXv39jn/9NNPZ9++fdU1zzZs2LA0Njb2eQAADDQDMgDPOOOM7Ny5s8+xRx55JBMnTkySTJo0KS0tLdm0aVP1fHd3d7Zu3Zq2trYkSVtbW/bv35/t27dX19x5553p7e3NtGnTarYXAIBaG1LvAf4QF198cU4//fRcffXVefe735177703119/fa6//vokSUNDQxYsWJArr7wyxx13XCZNmpTLLrssra2tOffcc5Pf3DF829veVv3W8VNPPZX58+dn9uzZv9dPAAMADFQDMgBPO+203HrrrVm8eHE+9alPZdKkSbnuuuvS3t5eXfPhD384Bw8ezLx587J///684Q1vyIYNGzJ8+PDqmptuuinz58/PmWeemUGDBmXWrFlZsWJFnXYFAFAbA/JzAI8UPgcQAAYenwM4QN8DCADAH04AAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRGAAIAFEYAAgAURgACABRmSL0HgOdy7KJvPO+5ny2bWdNZAODFxh1AAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMIIQACAwghAAIDCCEAAgMK8KAJw2bJlaWhoyIIFC6rHnnzyyXR0dGTs2LF5
2ctellmzZmXPnj19Xrdr167MnDkzL33pSzNu3Lhceumlefrpp+uwAwCA2hnwAbht27Z84QtfyMknn9zn+MUXX5yvf/3r+cpXvpK77rorjz/+eN75zndWzx8+fDgzZ87MoUOHcvfdd2fdunVZu3ZtlixZUoddAADUzoAOwAMHDqS9vT033HBDjjrqqOrxrq6u3Hjjjbn22mvz1re+NaeeemrWrFmTu+++O/fcc0+S5I477shDDz2UL37xi3nta1+bs88+O1dccUVWrlyZQ4cO1XFXAAD9a0AHYEdHR2bOnJnp06f3Ob59+/Y89dRTfY5Pnjw5EyZMyJYtW5IkW7ZsyUknnZTm5ubqmhkzZqS7uzsPPvjgc16vp6cn3d3dfR4AAAPNkHoP8Ie65ZZb8oMf/CDbtm37nXOdnZ0ZOnRoRo8e3ed4c3NzOjs7q2t+O/6eOf/MueeydOnSXH755S/gLgAAam9A3gHcvXt3Lrrootx0000ZPnx4za67ePHidHV1VR+7d++u2bUBAF4oAzIAt2/fnr179+Z1r3tdhgwZkiFDhuSuu+7KihUrMmTIkDQ3N+fQoUPZv39/n9ft2bMnLS0tSZKWlpbf+angZ54/s+bZhg0blsbGxj4PAICBZkAG4Jlnnpn7778/O3bsqD6mTp2a9vb26v9+yUtekk2bNlVfs3PnzuzatSttbW1Jkra2ttx///3Zu3dvdc3GjRvT2NiYKVOm1GVfAAC1MCDfAzhq1KiceOKJfY6NHDkyY8eOrR6fO3duFi5cmDFjxqSxsTEf+MAH0tbWlte//vVJkrPOOitTpkzJeeedl+XLl6ezszMf//jH09HRkWHDhtVlXwAAtTAgA/D38elPfzqDBg3KrFmz0tPTkxkzZuTzn/989fzgwYOzfv36vP/9709bW1tGjhyZOXPm5FOf+lRd5wYA6G8vmgD87ne/2+f58OHDs3LlyqxcufJ5XzNx4sR885vfrMF0AABHjgH5HkAAAP5wAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwAhAAoDACEACgMAIQAKAwQ+o9ANDXsYu+8bznfrZsZk1nAeDFyR1AAIDCDMgAXLp0aU477bSMGjUq48aNy7nnnpudO3f2WfPkk0+mo6MjY8eOzcte9rLMmjUre/bs6bNm165dmTlzZl760pdm3LhxufTSS/P000/XeDcAALU1IAPwrrvuSkdHR+65555s3LgxTz31VM4666wcPHiwuubiiy/O17/+9XzlK1/JXXfdlccffzzvfOc7q+cPHz6cmTNn5tChQ7n77ruzbt26rF27NkuWLKnTrgAAamNAvgdww4YNfZ6vXbs248aNy/bt2/OmN70pXV1dufHGG3PzzTfnrW99a5JkzZo1Of7443PPPffk9a9/fe6444489NBD+fa3v53m5ua89rWvzRVXXJGPfOQj+eQnP5mhQ4fWaXcAAP1rQN4BfLaurq4kyZgxY5Ik27dvz1NPPZXp06dX10yePDkTJkzIli1bkiRbtmzJSSedlObm5uqaGTNmpLu7Ow8++GDN9wAAUCsD8g7gb+vt7c2CBQtyxhln5MQTT0ySdHZ2ZujQoRk9enSftc3Nzens7Kyu+e34e+b8M+eeS09PT3p6eqrPu7u7X/D9AAD0twF/B7CjoyMPPPBAbrnlln6/1tKlS9PU1FR9jB8/vt+vCQDwQhvQATh//vysX78+3/nOd3LMMcdUj7e0
tOTQoUPZv39/n/V79uxJS0tLdc2zfyr4mefPrHm2xYsXp6urq/rYvXt3P+wKAKB/DcgArFQqmT9/fm699dbceeedmTRpUp/zp556al7ykpdk06ZN1WM7d+7Mrl270tbWliRpa2vL/fffn71791bXbNy4MY2NjZkyZcpzXnfYsGFpbGzs8wAAGGgG5HsAOzo6cvPNN+erX/1qRo0aVX3PXlNTU0aMGJGmpqbMnTs3CxcuzJgxY9LY2JgPfOADaWtry+tf//okyVlnnZUpU6bkvPPOy/Lly9PZ2ZmPf/zj6ejoyLBhw+q8QwCA/jMgA3DVqlVJkj/7sz/rc3zNmjV53/velyT59Kc/nUGDBmXWrFnp6enJjBkz8vnPf766dvDgwVm/fn3e//73p62tLSNHjsycOXPyqU99qsa7AQCorQEZgJVK5X9cM3z48KxcuTIrV6583jUTJ07MN7/5zRd4OgCAI9uAfA8gAAB/OAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUBgBCABQGAEIAFAYAQgAUJgh9R4A4Lcdu+gbz3vuZ8tm1nQWgBcrdwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKIwABAAojAAEACiMAAQAKM6TeAwDwfx276BvPe+5ny2bWdBbgxc0dQACAwghAAIDCFB+AK1euzLHHHpvhw4dn2rRpuffee+s9EgBAvyr6PYBf+tKXsnDhwqxevTrTpk3LddddlxkzZmTnzp0ZN25cvccDKIr3QELtFH0H8Nprr82FF16Y888/P1OmTMnq1avz0pe+NP/4j/9Y79EAAPpNsXcADx06lO3bt2fx4sXVY4MGDcr06dOzZcuW53xNT09Penp6qs+7urqSJN3d3f0yY2/Pr5/3XH9d80hh78/txb73FL7/kvce+y/WiZ+4/XnPPXD5jH655jO/nyqVSr98/YGg2AD87//+7xw+fDjNzc19jjc3N+fHP/7xc75m6dKlufzyy3/n+Pjx4/ttzufTdF3NL3nEsPdylbz/kvce+y9Wf////sQTT6Spqal/L3KEKjYA/xCLFy/OwoULq897e3uzb9++jB07Ng0NDS/otbq7uzN+/Pjs3r07jY2NL+jXHgjsv+z9x69B8fuPXwP778f9VyqVPPHEE2ltbX1Bv+5AUmwAvvzlL8/gwYOzZ8+ePsf37NmTlpaW53zNsGHDMmzYsD7HRo8e3a9zNjY2FvkP/jPsv+z9x69B8fuPXwP776f9l3rn7xnF/hDI0KFDc+qpp2bTpk3VY729vdm0aVPa2trqOhsAQH8q9g5gkixcuDBz5szJ1KlT86d/+qe57rrrcvDgwZx//vn1Hg0AoN8UHYDvec978l//9V9ZsmRJOjs789rXvjYbNmz4nR8MqYdhw4blE5/4xO98y7kU9l/2/uPXoPj9x6+B/Re+//7WUCn5Z6ABAApU7HsAAQBKJQABAAojAAEACiMAAQAKIwCPMJs3b84555yT1tbWNDQ05Lbbbqv3SDW1dOnSnHbaaRk1alTGjRuXc889Nzt37qz3WDWzatWqnHzyydUPPm1ra8u3vvWteo9VN8uWLUtDQ0MWLFhQ71Fq5pOf/GQaGhr6PCZPnlzvsWrql7/8Zd773vdm7NixGTFiRE466aTcd9999R6rZo499tjf+T3Q0NCQjo6Oeo9WE4cPH85ll12WSZMmZcSIEXnVq16VK664oui/t7c/FP0xMEeigwcP5jWveU0uuOCCvPOd76z3ODV31113paOjI6eddlqefvrpfPSjH81ZZ52Vhx56KCNH
jqz3eP3umGOOybJly3LcccelUqlk3bp1ecc73pEf/vCHOeGEE+o9Xk1t27YtX/jCF3LyySfXe5SaO+GEE/Ltb3+7+nzIkHL+Vf2rX/0qZ5xxRt7ylrfkW9/6Vl7xilfk0UcfzVFHHVXv0Wpm27ZtOXz4cPX5Aw88kD//8z/Pu971rrrOVSvXXHNNVq1alXXr1uWEE07Ifffdl/PPPz9NTU354Ac/WO/xXjTK+bfKAHH22Wfn7LPPrvcYdbNhw4Y+z9euXZtx48Zl+/btedOb3lS3uWrlnHPO6fP8qquuyqpVq3LPPfcUFYAHDhxIe3t7brjhhlx55ZX1HqfmhgwZ8rx/JeWL3TXXXJPx48dnzZo11WOTJk2q60y19opXvKLP82XLluVVr3pV3vzmN9dtplq6++678453vCMzZ85MfnNH9J/+6Z9y77331nu0FxXfAuaI1tXVlSQZM2ZMvUepucOHD+eWW27JwYMHi/vrCTs6OjJz5sxMnz693qPUxaOPPprW1ta88pWvTHt7e3bt2lXvkWrma1/7WqZOnZp3vetdGTduXE455ZTccMMN9R6rbg4dOpQvfvGLueCCC9LQ0FDvcWri9NNPz6ZNm/LII48kSX70ox/le9/7XtE3R/qDO4AcsXp7e7NgwYKcccYZOfHEE+s9Ts3cf//9aWtry5NPPpmXvexlufXWWzNlypR6j1Uzt9xyS37wgx9k27Zt9R6lLqZNm5a1a9fm1a9+df7zP/8zl19+ed74xjfmgQceyKhRo+o9Xr/76U9/mlWrVmXhwoX56Ec/mm3btuWDH/xghg4dmjlz5tR7vJq77bbbsn///rzvfe+r9yg1s2jRonR3d2fy5MkZPHhwDh8+nKuuuirt7e31Hu1FRQByxOro6MgDDzyQ733ve/UepaZe/epXZ8eOHenq6so///M/Z86cObnrrruKiMDdu3fnoosuysaNGzN8+PB6j1MXv32X4+STT860adMyceLEfPnLX87cuXPrOlst9Pb2ZurUqbn66quTJKecckoeeOCBrF69usgAvPHGG3P22WentbW13qPUzJe//OXcdNNNufnmm3PCCSdkx44dWbBgQVpbW4v8PdBfBCBHpPnz52f9+vXZvHlzjjnmmHqPU1NDhw7NH//xHydJTj311Gzbti2f+cxn8oUvfKHeo/W77du3Z+/evXnd615XPXb48OFs3rw5n/vc59LT05PBgwfXdcZaGz16dP7kT/4kP/nJT+o9Sk0cffTRv/MfO8cff3z+5V/+pW4z1cvPf/7zfPvb386//uu/1nuUmrr00kuzaNGizJ49O0ly0kkn5ec//3mWLl0qAF9AApAjSqVSyQc+8IHceuut+e53v1vcm7+fS29vb3p6euo9Rk2ceeaZuf/++/scO//88zN58uR85CMfKS7+8psfiPmP//iPnHfeefUepSbOOOOM3/nop0ceeSQTJ06s20z1smbNmowbN676wxCl+PWvf51Bg/r+iMLgwYPT29tbt5lejATgEebAgQN9/kv/sccey44dOzJmzJhMmDChrrPVQkdHR26++eZ89atfzahRo9LZ2ZkkaWpqyogRI+o9Xr9bvHhxzj777EyYMCFPPPFEbr755nz3u9/N7bffXu/RamLUqFG/837PkSNHZuzYscW8D/SSSy7JOeeck4kTJ+bxxx/PJz7xiQwePDh/+Zd/We/RauLiiy/O6aefnquvvjrvfve7c++99+b666/P9ddfX+/Raqq3tzdr1qzJnDlzivoYoPzm0xCuuuqqTJgwISeccEJ++MMf5tprr80FF1xQ79FeXCocUb7zne9UkvzOY86cOfUerSaea+9JKmvWrKn3aDVxwQUXVCZOnFgZOnRo5RWveEXlzDPPrNxxxx31Hquu3vzmN1cuuuiieo9RM+95z3sqRx99dGXo0KGVP/qjP6q85z3vqfzkJz+p91g19fWvf71y4oknVoYNG1aZPHly5frrr6/3SDV3++23
V5JUdu7cWe9Raq67u7ty0UUXVSZMmFAZPnx45ZWvfGXlYx/7WKWnp6feo72oNFR8tDYAQFF8DiAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGEEIABAYQQgAEBhBCAAQGH+D0qdXUuXgornAAAAAElFTkSuQmCC",
+ "text/html": [
+ "\n",
+ " <div style=\"display: inline-block;\">\n",
+ " <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
+ " Figure\n",
+ " </div>\n",
+ " <img src='' width=640.0/>\n",
+ " </div>\n",
+ " "
+ ],
+ "text/plain": [
+ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
- "print(f\"dataset size = {os.path.getsize('data.pkl')//(1024*1024)}MB\")"
+ "with plt.ioff():\n",
+ " plt.hist(tune_data['labels'],bins=64)\n",
+ " plt.show()"
]
},
{
@@ -321,16 +338,22 @@
},
{
"cell_type": "code",
- "execution_count": 107,
+ "execution_count": 51,
+ "execution_state": "idle",
"metadata": {
"id": "tLOWhg_CeWzH"
},
"outputs": [],
"source": [
"class TransformerModel(nn.Module):\n",
- " def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, device, dropout=0.1):\n",
+ " def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, device, dropout):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(input_dim, model_dim//2)\n",
+ "        # Random, untrained positional code: one row per consecutive pair of tokens\n",
+ "        # (presumably the two endpoints of an edge), duplicated so both tokens of a\n",
+ "        # pair receive the same code.\n",
+ "        # seq_len is odd, so build seq_len // 2 + 1 pairs and drop the surplus copy;\n",
+ "        # the trailing target-vertex token then keeps a code of its own.\n",
+ "        pair_encoding = torch.rand((1, seq_len // 2 + 1, model_dim // 2), device=device)\n",
+ "        fancy_encoding = torch.repeat_interleave(pair_encoding, 2, dim=1)[:, :seq_len, :]\n",
+ "        # Register as a buffer so the random code is stored in state_dict(); as a plain\n",
+ "        # attribute it was left out of saved checkpoints, and a freshly constructed\n",
+ "        # model would pair the loaded weights with a different random encoding.\n",
+ "        self.register_buffer(\"fancy_encoding\", fancy_encoding)\n",
+ " \n",
" self.model_dim = model_dim\n",
" self.seq_len = seq_len\n",
" self.device = device\n",
@@ -341,29 +364,17 @@
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
"\n",
" self.fc_out = nn.Linear(model_dim*seq_len, output_dim)\n",
- " self.fancy_encoding = torch.repeat_interleave(torch.rand((1,SEQ_LEN // 2, model_dim // 2), device=device), 2, dim=1)\n",
- " \n",
- " def positional_encoding(self, batch_size):\n",
- " position = torch.arange(self.seq_len, dtype=torch.float, device=self.device).unsqueeze(1)\n",
- " div_term = torch.exp(torch.arange(0, self.model_dim, 2, dtype=torch.float, device=self.device) *\n",
- " -(torch.log(torch.tensor(500.0)) / self.model_dim))\n",
- "\n",
- " pos_encoding = torch.zeros(self.seq_len, self.model_dim, device=self.device)\n",
- " pos_encoding[:, 0::2] = torch.sin(position * div_term)\n",
- " pos_encoding[:, 1::2] = torch.cos(position * div_term)\n",
- " pos_encoding = pos_encoding.unsqueeze(0).repeat(batch_size, 1, 1)\n",
- " return pos_encoding\n",
"\n",
" def forward(self, src, key_padding_mask):\n",
" batch_size, src_len = src.size(0), src.size(1)\n",
- " # src_pos = self.positional_encoding(batch_size)\n",
" embed = self.embedding(src)\n",
- " src = torch.cat((embed * sqrt(self.model_dim), torch.Tensor.repeat(self.fancy_encoding, (batch_size, 1, 1))), dim=2)\n",
+ " src = torch.cat((embed * sqrt(self.model_dim), self.fancy_encoding.repeat((batch_size, 1, 1))), dim=2)\n",
"\n",
- " output = self.transformer_encoder(src, None, src_key_padding_mask=key_padding_mask)\n",
+ " output = self.transformer_encoder(src, src_key_padding_mask=key_padding_mask)\n",
+ " output[key_padding_mask] = 0 # Hack to stop no_grad problem\n",
" flat_output = torch.flatten(output, start_dim=1, end_dim=2)\n",
" output = self.fc_out(flat_output)\n",
- " return output\n"
+ " return output"
]
},
{
@@ -377,7 +388,8 @@
},
{
"cell_type": "code",
- "execution_count": 121,
+ "execution_count": 78,
+ "execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -390,18 +402,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Trainable parameters in the model: 102K\n",
- "train BASELINEs: 0.1290\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/tmp/ipykernel_390590/1991115476.py:23: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " train_data_tensor = torch.tensor(train_data1, dtype=torch.long, device=device)\n",
- "/tmp/ipykernel_390590/1991115476.py:31: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " test_data_tensor = torch.tensor(test_data1, dtype=torch.long, device=device)\n"
+ "Trainable parameters in the model: 2390K\n"
]
}
],
@@ -410,49 +411,56 @@
"assert device.type == 'cuda', \"CUDA is not available. Please check your GPU setup.\"\n",
"\n",
"# PARAMS\n",
- "VOCAB_SIZE = 1+MAX_VTXS # one more than the max number of vertices\n",
- "MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
- "NEPOCHS = 50\n",
- "BSZ = 512\n",
- "LR = 0.001\n",
+ "VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n",
+ "MODEL_DIM = 256 # Dimension of model (embedding and transformer)\n",
+ "NEPOCHS = 1000\n",
+ "BSZ = 3072\n",
+ "LR = 0.003\n",
+ "WD = 0.002\n",
"NHEADS = 4\n",
- "NLAYERS = 2\n",
+ "NLAYERS = 3\n",
"PAD_TOKEN = 0\n",
+ "DROPOUT = 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
" num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
- " device=device).to(device)\n",
+ " dropout=DROPOUT, device=device).to(device)\n",
"\n",
"with open(\"data.pkl\", \"rb\") as f:\n",
- " data = pickle.load(f)\n",
- "\n",
- "train_data1 = data[\"train1-data\"]\n",
- "train_label1 = data[\"train1-labels\"]\n",
- "train_data_tensor = torch.tensor(train_data1, dtype=torch.long, device=device)\n",
- "train_label_tensor = torch.tensor(train_label1, dtype=torch.float, device=device)\n",
- "train_padding_mask = (train_data_tensor == PAD_TOKEN).bool().to(device)\n",
- "train_dataset = TensorDataset(train_data_tensor, train_label_tensor, train_padding_mask)\n",
- "train_loader = DataLoader(train_dataset, batch_size=BSZ, shuffle=True)\n",
+ " pickled_stuff = pickle.load(f)\n",
"\n",
- "test_data1 = data[\"test1-data\"]\n",
- "test_label1 = data[\"test1-labels\"]\n",
- "test_data_tensor = torch.tensor(test_data1, dtype=torch.long, device=device)\n",
- "test_label_tensor = torch.tensor(test_label1, dtype=torch.float, device=device)\n",
- "test_padding_mask = (test_data_tensor == PAD_TOKEN).bool().to(device)\n",
- "test_dataset = TensorDataset(test_data_tensor, test_label_tensor, test_padding_mask)\n",
+ "data = pickled_stuff[\"data\"].to(device)\n",
+ "label = pickled_stuff[\"labels\"].to(device)\n",
+ "padding_mask = (data == PAD_TOKEN).bool().to(device)\n",
+ "dataset = TensorDataset(data, label, padding_mask)\n",
+ "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [.8, .2])\n",
+ "train_loader = DataLoader(train_dataset, batch_size=BSZ, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=BSZ, shuffle=True)\n",
"\n",
"criterion = nn.MSELoss()\n",
- "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
- "\n",
- "train_err = []\n",
- "test_err = []\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n",
"\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
- "print(f\"Trainable parameters in the model: {trainable_params//1000}K\")\n",
- "\n",
- "train_baseline = ((train_label_tensor - train_label_tensor.mean())**2).mean().item()\n",
- "print(f\"train BASELINEs: {train_baseline:.4f}\")"
+ "print(f\"Trainable parameters in the model: {trainable_params//1000}K\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "execution_state": "idle",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor(141.4637, device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Reference error of a trivial predictor that outputs the constant 1.5 for every sample.\n",
+ "# NOTE(review): 1.5 is a hand-picked constant; label.mean() would be the best constant\n",
+ "# predictor under MSE -- confirm which baseline is intended.\n",
+ "# The constant is expanded to label's shape so MSELoss receives equally sized arguments\n",
+ "# (a 0-dim target triggers PyTorch's size-mismatch broadcasting warning); the resulting\n",
+ "# value is the same.\n",
+ "baseline_error = criterion(label, torch.full_like(label, 1.5, dtype=torch.float32))\n",
+ "print(baseline_error)"
]
},
{
@@ -466,7 +474,65 @@
},
{
"cell_type": "code",
- "execution_count": 122,
+ "execution_count": 60,
+ "execution_state": "idle",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def evaluate():\n",
+ "    \"\"\"Return the per-sample mean of `criterion` for `model` over `test_loader`.\n",
+ "\n",
+ "    Each batch loss (a mean over the batch, since `criterion` is `nn.MSELoss()`)\n",
+ "    is weighted by its batch size, so a smaller final batch does not count as\n",
+ "    much as a full one. Switches `model` to eval mode as a side effect and runs\n",
+ "    without gradient tracking.\n",
+ "    \"\"\"\n",
+ "    model.eval()\n",
+ "    total_loss = 0.0\n",
+ "    total_samples = 0\n",
+ "    with torch.no_grad():\n",
+ "        for batch_src, batch_labels, batch_padding_mask in test_loader:\n",
+ "            output = model(batch_src, batch_padding_mask)\n",
+ "            loss = criterion(output.squeeze(1), batch_labels)\n",
+ "            batch_size = batch_labels.size(0)\n",
+ "            # Undo the per-batch mean so the final division is over samples.\n",
+ "            total_loss += loss.item() * batch_size\n",
+ "            total_samples += batch_size\n",
+ "    # An empty loader yields 0, as before.\n",
+ "    return total_loss / total_samples if total_samples else 0.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "execution_state": "idle",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "329425e6ee6d4189aefee350eba741c7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "image/png": "",
+ "text/html": [
+ "\n",
+ " <div style=\"display: inline-block;\">\n",
+ " <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
+ " Figure\n",
+ " </div>\n",
+ " <img src='' width=640.0/>\n",
+ " </div>\n",
+ " "
+ ],
+ "text/plain": [
+ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# This has to be in a separate cell for some weird event loop reasons\n",
+ "fig,ax = plt.subplots()\n",
+ "fig.suptitle('MSE vs Epochs')\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -480,74 +546,173 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 1/50 \t Train Err: 0.1621 \t Test Err: 0.1208 \t baseline err: 0.1290\n",
- "Epoch 2/50 \t Train Err: 0.1266 \t Test Err: 0.1201 \t baseline err: 0.1290\n",
- "Epoch 3/50 \t Train Err: 0.1224 \t Test Err: 0.1199 \t baseline err: 0.1290\n",
- "Epoch 4/50 \t Train Err: 0.1190 \t Test Err: 0.1214 \t baseline err: 0.1290\n",
- "Epoch 5/50 \t Train Err: 0.1167 \t Test Err: 0.1164 \t baseline err: 0.1290\n",
- "Epoch 6/50 \t Train Err: 0.1154 \t Test Err: 0.1156 \t baseline err: 0.1290\n",
- "Epoch 7/50 \t Train Err: 0.1146 \t Test Err: 0.1131 \t baseline err: 0.1290\n",
- "Epoch 8/50 \t Train Err: 0.1140 \t Test Err: 0.1145 \t baseline err: 0.1290\n",
- "Epoch 9/50 \t Train Err: 0.1135 \t Test Err: 0.1144 \t baseline err: 0.1290\n",
- "Epoch 10/50 \t Train Err: 0.1134 \t Test Err: 0.1160 \t baseline err: 0.1290\n",
- "Epoch 11/50 \t Train Err: 0.1134 \t Test Err: 0.1160 \t baseline err: 0.1290\n",
- "Epoch 12/50 \t Train Err: 0.1129 \t Test Err: 0.1137 \t baseline err: 0.1290\n",
- "Epoch 13/50 \t Train Err: 0.1131 \t Test Err: 0.1122 \t baseline err: 0.1290\n",
- "Epoch 14/50 \t Train Err: 0.1125 \t Test Err: 0.1133 \t baseline err: 0.1290\n",
- "Epoch 15/50 \t Train Err: 0.1121 \t Test Err: 0.1119 \t baseline err: 0.1290\n",
- "Epoch 16/50 \t Train Err: 0.1120 \t Test Err: 0.1129 \t baseline err: 0.1290\n",
- "Epoch 17/50 \t Train Err: 0.1123 \t Test Err: 0.1123 \t baseline err: 0.1290\n",
- "Epoch 18/50 \t Train Err: 0.1120 \t Test Err: 0.1119 \t baseline err: 0.1290\n",
- "Epoch 19/50 \t Train Err: 0.1117 \t Test Err: 0.1148 \t baseline err: 0.1290\n",
- "Epoch 20/50 \t Train Err: 0.1119 \t Test Err: 0.1136 \t baseline err: 0.1290\n",
- "Epoch 21/50 \t Train Err: 0.1117 \t Test Err: 0.1120 \t baseline err: 0.1290\n",
- "Epoch 22/50 \t Train Err: 0.1114 \t Test Err: 0.1123 \t baseline err: 0.1290\n",
- "Epoch 23/50 \t Train Err: 0.1111 \t Test Err: 0.1121 \t baseline err: 0.1290\n",
- "Epoch 24/50 \t Train Err: 0.1093 \t Test Err: 0.1061 \t baseline err: 0.1290\n",
- "Epoch 25/50 \t Train Err: 0.1044 \t Test Err: 0.1012 \t baseline err: 0.1290\n",
- "Epoch 26/50 \t Train Err: 0.1012 \t Test Err: 0.1003 \t baseline err: 0.1290\n",
- "Epoch 27/50 \t Train Err: 0.0985 \t Test Err: 0.0964 \t baseline err: 0.1290\n",
- "Epoch 28/50 \t Train Err: 0.0957 \t Test Err: 0.0942 \t baseline err: 0.1290\n",
- "Epoch 29/50 \t Train Err: 0.0947 \t Test Err: 0.0935 \t baseline err: 0.1290\n",
- "Epoch 30/50 \t Train Err: 0.0931 \t Test Err: 0.0941 \t baseline err: 0.1290\n",
- "Epoch 31/50 \t Train Err: 0.0920 \t Test Err: 0.0916 \t baseline err: 0.1290\n",
- "Epoch 32/50 \t Train Err: 0.0893 \t Test Err: 0.0857 \t baseline err: 0.1290\n",
- "Epoch 33/50 \t Train Err: 0.0868 \t Test Err: 0.0814 \t baseline err: 0.1290\n",
- "Epoch 34/50 \t Train Err: 0.0827 \t Test Err: 0.0785 \t baseline err: 0.1290\n",
- "Epoch 35/50 \t Train Err: 0.0770 \t Test Err: 0.0720 \t baseline err: 0.1290\n",
- "Epoch 36/50 \t Train Err: 0.0713 \t Test Err: 0.0646 \t baseline err: 0.1290\n",
- "Epoch 37/50 \t Train Err: 0.0642 \t Test Err: 0.0540 \t baseline err: 0.1290\n",
- "Epoch 38/50 \t Train Err: 0.0588 \t Test Err: 0.0501 \t baseline err: 0.1290\n",
- "Epoch 39/50 \t Train Err: 0.0543 \t Test Err: 0.0456 \t baseline err: 0.1290\n",
- "Epoch 40/50 \t Train Err: 0.0488 \t Test Err: 0.0366 \t baseline err: 0.1290\n",
- "Epoch 41/50 \t Train Err: 0.0416 \t Test Err: 0.0315 \t baseline err: 0.1290\n",
- "Epoch 42/50 \t Train Err: 0.0360 \t Test Err: 0.0214 \t baseline err: 0.1290\n",
- "Epoch 43/50 \t Train Err: 0.0305 \t Test Err: 0.0172 \t baseline err: 0.1290\n",
- "Epoch 44/50 \t Train Err: 0.0239 \t Test Err: 0.0116 \t baseline err: 0.1290\n",
- "Epoch 45/50 \t Train Err: 0.0205 \t Test Err: 0.0117 \t baseline err: 0.1290\n",
- "Epoch 46/50 \t Train Err: 0.0181 \t Test Err: 0.0092 \t baseline err: 0.1290\n",
- "Epoch 47/50 \t Train Err: 0.0164 \t Test Err: 0.0100 \t baseline err: 0.1290\n",
- "Epoch 48/50 \t Train Err: 0.0155 \t Test Err: 0.0081 \t baseline err: 0.1290\n",
- "Epoch 49/50 \t Train Err: 0.0141 \t Test Err: 0.0074 \t baseline err: 0.1290\n",
- "Epoch 50/50 \t Train Err: 0.0129 \t Test Err: 0.0075 \t baseline err: 0.1290\n"
+ "Epoch 1/1000 \t Train Err: 88.4438 \t Test Err: 72.6466 \t baseline err: 141.4637\n",
+ "Epoch 2/1000 \t Train Err: 72.6060 \t Test Err: 75.4987 \t baseline err: 141.4637\n",
+ "Epoch 3/1000 \t Train Err: 72.6729 \t Test Err: 73.9506 \t baseline err: 141.4637\n",
+ "Epoch 6/1000 \t Train Err: 72.7866 \t Test Err: 72.4612 \t baseline err: 141.4637\n",
+ "Epoch 7/1000 \t Train Err: 73.0815 \t Test Err: 74.9404 \t baseline err: 141.4637\n",
+ "Epoch 8/1000 \t Train Err: 72.8054 \t Test Err: 73.7412 \t baseline err: 141.4637\n",
+ "Epoch 9/1000 \t Train Err: 73.0514 \t Test Err: 82.7393 \t baseline err: 141.4637\n",
+ "Epoch 10/1000 \t Train Err: 73.0291 \t Test Err: 72.2577 \t baseline err: 141.4637\n",
+ "Epoch 11/1000 \t Train Err: 73.0081 \t Test Err: 72.4944 \t baseline err: 141.4637\n",
+ "Epoch 12/1000 \t Train Err: 72.6500 \t Test Err: 75.2036 \t baseline err: 141.4637\n",
+ "Epoch 13/1000 \t Train Err: 72.6255 \t Test Err: 73.4970 \t baseline err: 141.4637\n",
+ "Epoch 14/1000 \t Train Err: 72.6803 \t Test Err: 74.0005 \t baseline err: 141.4637\n",
+ "Epoch 15/1000 \t Train Err: 72.7032 \t Test Err: 73.0177 \t baseline err: 141.4637\n",
+ "Epoch 16/1000 \t Train Err: 72.7891 \t Test Err: 75.1899 \t baseline err: 141.4637\n",
+ "Epoch 17/1000 \t Train Err: 74.0133 \t Test Err: 71.7237 \t baseline err: 141.4637\n",
+ "Epoch 18/1000 \t Train Err: 72.7520 \t Test Err: 75.8566 \t baseline err: 141.4637\n",
+ "Epoch 19/1000 \t Train Err: 72.5771 \t Test Err: 74.9531 \t baseline err: 141.4637\n",
+ "Epoch 20/1000 \t Train Err: 72.6114 \t Test Err: 73.3918 \t baseline err: 141.4637\n",
+ "Epoch 21/1000 \t Train Err: 71.5844 \t Test Err: 57.3829 \t baseline err: 141.4637\n",
+ "Epoch 22/1000 \t Train Err: 56.8166 \t Test Err: 60.2253 \t baseline err: 141.4637\n",
+ "Epoch 23/1000 \t Train Err: 58.2172 \t Test Err: 56.7333 \t baseline err: 141.4637\n",
+ "Epoch 24/1000 \t Train Err: 56.1189 \t Test Err: 55.7485 \t baseline err: 141.4637\n",
+ "Epoch 25/1000 \t Train Err: 55.5304 \t Test Err: 56.2083 \t baseline err: 141.4637\n",
+ "Epoch 26/1000 \t Train Err: 68.7059 \t Test Err: 72.6976 \t baseline err: 141.4637\n",
+ "Epoch 27/1000 \t Train Err: 72.7020 \t Test Err: 73.1029 \t baseline err: 141.4637\n",
+ "Epoch 28/1000 \t Train Err: 72.4459 \t Test Err: 73.5617 \t baseline err: 141.4637\n",
+ "Epoch 29/1000 \t Train Err: 72.5310 \t Test Err: 75.8304 \t baseline err: 141.4637\n",
+ "Epoch 30/1000 \t Train Err: 72.5256 \t Test Err: 73.0845 \t baseline err: 141.4637\n",
+ "Epoch 31/1000 \t Train Err: 72.4667 \t Test Err: 72.9080 \t baseline err: 141.4637\n",
+ "Epoch 32/1000 \t Train Err: 72.5369 \t Test Err: 72.6703 \t baseline err: 141.4637\n",
+ "Epoch 33/1000 \t Train Err: 72.4685 \t Test Err: 74.7614 \t baseline err: 141.4637\n",
+ "Epoch 34/1000 \t Train Err: 72.4926 \t Test Err: 74.0886 \t baseline err: 141.4637\n",
+ "Epoch 35/1000 \t Train Err: 71.3339 \t Test Err: 55.4380 \t baseline err: 141.4637\n",
+ "Epoch 36/1000 \t Train Err: 60.7870 \t Test Err: 59.4468 \t baseline err: 141.4637\n",
+ "Epoch 37/1000 \t Train Err: 55.6557 \t Test Err: 56.7001 \t baseline err: 141.4637\n",
+ "Epoch 38/1000 \t Train Err: 55.4896 \t Test Err: 55.8308 \t baseline err: 141.4637\n",
+ "Epoch 39/1000 \t Train Err: 55.6962 \t Test Err: 58.9664 \t baseline err: 141.4637\n",
+ "Epoch 40/1000 \t Train Err: 55.5519 \t Test Err: 57.7560 \t baseline err: 141.4637\n",
+ "Epoch 41/1000 \t Train Err: 55.5370 \t Test Err: 56.6866 \t baseline err: 141.4637\n",
+ "Epoch 42/1000 \t Train Err: 55.4300 \t Test Err: 56.3424 \t baseline err: 141.4637\n",
+ "Epoch 43/1000 \t Train Err: 55.4922 \t Test Err: 56.3748 \t baseline err: 141.4637\n",
+ "Epoch 44/1000 \t Train Err: 55.6073 \t Test Err: 59.0728 \t baseline err: 141.4637\n",
+ "Epoch 45/1000 \t Train Err: 55.5497 \t Test Err: 58.5533 \t baseline err: 141.4637\n",
+ "Epoch 46/1000 \t Train Err: 55.4837 \t Test Err: 57.2847 \t baseline err: 141.4637\n",
+ "Epoch 47/1000 \t Train Err: 55.4173 \t Test Err: 57.1441 \t baseline err: 141.4637\n",
+ "Epoch 48/1000 \t Train Err: 55.4576 \t Test Err: 55.7806 \t baseline err: 141.4637\n",
+ "Epoch 49/1000 \t Train Err: 55.5678 \t Test Err: 56.6457 \t baseline err: 141.4637\n",
+ "Epoch 50/1000 \t Train Err: 55.5537 \t Test Err: 60.0365 \t baseline err: 141.4637\n",
+ "Epoch 51/1000 \t Train Err: 55.5123 \t Test Err: 55.6848 \t baseline err: 141.4637\n",
+ "Epoch 52/1000 \t Train Err: 55.5872 \t Test Err: 55.8084 \t baseline err: 141.4637\n",
+ "Epoch 53/1000 \t Train Err: 55.4548 \t Test Err: 56.5655 \t baseline err: 141.4637\n",
+ "Epoch 54/1000 \t Train Err: 55.5124 \t Test Err: 56.3470 \t baseline err: 141.4637\n",
+ "Epoch 55/1000 \t Train Err: 55.4518 \t Test Err: 57.6169 \t baseline err: 141.4637\n",
+ "Epoch 56/1000 \t Train Err: 55.4073 \t Test Err: 55.6467 \t baseline err: 141.4637\n",
+ "Epoch 57/1000 \t Train Err: 55.4745 \t Test Err: 56.3436 \t baseline err: 141.4637\n",
+ "Epoch 58/1000 \t Train Err: 55.4862 \t Test Err: 56.2289 \t baseline err: 141.4637\n",
+ "Epoch 59/1000 \t Train Err: 55.5221 \t Test Err: 55.3599 \t baseline err: 141.4637\n",
+ "Epoch 60/1000 \t Train Err: 55.4843 \t Test Err: 55.3953 \t baseline err: 141.4637\n",
+ "Epoch 61/1000 \t Train Err: 55.5095 \t Test Err: 56.4781 \t baseline err: 141.4637\n",
+ "Epoch 62/1000 \t Train Err: 55.6532 \t Test Err: 56.4005 \t baseline err: 141.4637\n",
+ "Epoch 63/1000 \t Train Err: 55.5240 \t Test Err: 57.4780 \t baseline err: 141.4637\n",
+ "Epoch 64/1000 \t Train Err: 55.4915 \t Test Err: 55.8880 \t baseline err: 141.4637\n",
+ "Epoch 65/1000 \t Train Err: 55.4006 \t Test Err: 56.1770 \t baseline err: 141.4637\n",
+ "Epoch 66/1000 \t Train Err: 55.3153 \t Test Err: 56.3041 \t baseline err: 141.4637\n",
+ "Epoch 67/1000 \t Train Err: 55.3105 \t Test Err: 55.7897 \t baseline err: 141.4637\n",
+ "Epoch 68/1000 \t Train Err: 55.9038 \t Test Err: 54.9242 \t baseline err: 141.4637\n",
+ "Epoch 69/1000 \t Train Err: 55.4002 \t Test Err: 55.2162 \t baseline err: 141.4637\n",
+ "Epoch 70/1000 \t Train Err: 55.5265 \t Test Err: 54.4618 \t baseline err: 141.4637\n",
+ "Epoch 71/1000 \t Train Err: 55.4598 \t Test Err: 56.2988 \t baseline err: 141.4637\n",
+ "Epoch 72/1000 \t Train Err: 55.4995 \t Test Err: 55.1318 \t baseline err: 141.4637\n",
+ "Epoch 73/1000 \t Train Err: 55.5224 \t Test Err: 55.6233 \t baseline err: 141.4637\n",
+ "Epoch 74/1000 \t Train Err: 55.2633 \t Test Err: 55.0628 \t baseline err: 141.4637\n",
+ "Epoch 75/1000 \t Train Err: 55.3569 \t Test Err: 54.9321 \t baseline err: 141.4637\n",
+ "Epoch 76/1000 \t Train Err: 55.4845 \t Test Err: 55.7232 \t baseline err: 141.4637\n",
+ "Epoch 77/1000 \t Train Err: 55.3814 \t Test Err: 54.6657 \t baseline err: 141.4637\n",
+ "Epoch 78/1000 \t Train Err: 55.4396 \t Test Err: 55.2952 \t baseline err: 141.4637\n",
+ "Epoch 79/1000 \t Train Err: 55.4018 \t Test Err: 55.4081 \t baseline err: 141.4637\n",
+ "Epoch 80/1000 \t Train Err: 55.5015 \t Test Err: 56.3544 \t baseline err: 141.4637\n",
+ "Epoch 81/1000 \t Train Err: 55.5352 \t Test Err: 55.8122 \t baseline err: 141.4637\n",
+ "Epoch 82/1000 \t Train Err: 55.4454 \t Test Err: 54.7959 \t baseline err: 141.4637\n",
+ "Epoch 83/1000 \t Train Err: 55.4375 \t Test Err: 55.1435 \t baseline err: 141.4637\n",
+ "Epoch 84/1000 \t Train Err: 55.4614 \t Test Err: 54.7396 \t baseline err: 141.4637\n",
+ "Epoch 85/1000 \t Train Err: 55.4046 \t Test Err: 55.3768 \t baseline err: 141.4637\n",
+ "Epoch 86/1000 \t Train Err: 55.3655 \t Test Err: 54.7487 \t baseline err: 141.4637\n",
+ "Epoch 87/1000 \t Train Err: 55.4036 \t Test Err: 55.0165 \t baseline err: 141.4637\n",
+ "Epoch 88/1000 \t Train Err: 55.4548 \t Test Err: 55.5787 \t baseline err: 141.4637\n",
+ "Epoch 89/1000 \t Train Err: 55.3973 \t Test Err: 54.7695 \t baseline err: 141.4637\n",
+ "Epoch 90/1000 \t Train Err: 55.4891 \t Test Err: 56.6628 \t baseline err: 141.4637\n",
+ "Epoch 91/1000 \t Train Err: 55.5870 \t Test Err: 56.8348 \t baseline err: 141.4637\n",
+ "Epoch 92/1000 \t Train Err: 55.4630 \t Test Err: 55.4178 \t baseline err: 141.4637\n",
+ "Epoch 93/1000 \t Train Err: 55.5218 \t Test Err: 55.8851 \t baseline err: 141.4637\n",
+ "Epoch 94/1000 \t Train Err: 55.4550 \t Test Err: 55.9211 \t baseline err: 141.4637\n",
+ "Epoch 95/1000 \t Train Err: 55.4727 \t Test Err: 56.5819 \t baseline err: 141.4637\n",
+ "Epoch 96/1000 \t Train Err: 55.4301 \t Test Err: 57.2222 \t baseline err: 141.4637\n",
+ "Epoch 97/1000 \t Train Err: 55.4108 \t Test Err: 55.3496 \t baseline err: 141.4637\n",
+ "Epoch 98/1000 \t Train Err: 55.4733 \t Test Err: 55.8675 \t baseline err: 141.4637\n",
+ "Epoch 99/1000 \t Train Err: 55.3536 \t Test Err: 56.2623 \t baseline err: 141.4637\n",
+ "Epoch 100/1000 \t Train Err: 55.2286 \t Test Err: 54.2883 \t baseline err: 141.4637\n",
+ "Epoch 101/1000 \t Train Err: 54.6294 \t Test Err: 54.3795 \t baseline err: 141.4637\n",
+ "Epoch 102/1000 \t Train Err: 54.0334 \t Test Err: 52.9438 \t baseline err: 141.4637\n",
+ "Epoch 103/1000 \t Train Err: 55.8557 \t Test Err: 55.9887 \t baseline err: 141.4637\n",
+ "Epoch 104/1000 \t Train Err: 55.3523 \t Test Err: 55.6809 \t baseline err: 141.4637\n",
+ "Epoch 105/1000 \t Train Err: 54.8650 \t Test Err: 54.1854 \t baseline err: 141.4637\n",
+ "Epoch 106/1000 \t Train Err: 54.8108 \t Test Err: 54.9449 \t baseline err: 141.4637\n",
+ "Epoch 107/1000 \t Train Err: 54.4932 \t Test Err: 53.3353 \t baseline err: 141.4637\n",
+ "Epoch 108/1000 \t Train Err: 54.0328 \t Test Err: 54.7242 \t baseline err: 141.4637\n",
+ "Epoch 109/1000 \t Train Err: 52.3047 \t Test Err: 50.1951 \t baseline err: 141.4637\n",
+ "Epoch 110/1000 \t Train Err: 46.5255 \t Test Err: 36.4212 \t baseline err: 141.4637\n",
+ "Epoch 111/1000 \t Train Err: 35.6437 \t Test Err: 36.2409 \t baseline err: 141.4637\n",
+ "Epoch 112/1000 \t Train Err: 35.5794 \t Test Err: 36.0152 \t baseline err: 141.4637\n",
+ "Epoch 113/1000 \t Train Err: 49.8337 \t Test Err: 56.2286 \t baseline err: 141.4637\n",
+ "Epoch 114/1000 \t Train Err: 55.4618 \t Test Err: 55.5159 \t baseline err: 141.4637\n",
+ "Epoch 115/1000 \t Train Err: 48.2926 \t Test Err: 37.0193 \t baseline err: 141.4637\n",
+ "Epoch 116/1000 \t Train Err: 43.0000 \t Test Err: 42.3349 \t baseline err: 141.4637\n",
+ "Epoch 117/1000 \t Train Err: 36.4887 \t Test Err: 38.2387 \t baseline err: 141.4637\n",
+ "Epoch 118/1000 \t Train Err: 35.7809 \t Test Err: 37.3943 \t baseline err: 141.4637\n",
+ "Epoch 119/1000 \t Train Err: 35.7078 \t Test Err: 37.0414 \t baseline err: 141.4637\n",
+ "Epoch 120/1000 \t Train Err: 35.7773 \t Test Err: 37.2562 \t baseline err: 141.4637\n",
+ "Epoch 121/1000 \t Train Err: 36.1376 \t Test Err: 38.1538 \t baseline err: 141.4637\n",
+ "Epoch 122/1000 \t Train Err: 38.3734 \t Test Err: 38.6760 \t baseline err: 141.4637\n",
+ "Epoch 123/1000 \t Train Err: 36.6591 \t Test Err: 38.0746 \t baseline err: 141.4637\n",
+ "Epoch 124/1000 \t Train Err: 37.1041 \t Test Err: 38.8434 \t baseline err: 141.4637\n",
+ "Epoch 125/1000 \t Train Err: 37.0860 \t Test Err: 37.3192 \t baseline err: 141.4637\n",
+ "Epoch 126/1000 \t Train Err: 35.8674 \t Test Err: 36.7553 \t baseline err: 141.4637\n",
+ "Epoch 127/1000 \t Train Err: 35.8025 \t Test Err: 36.0833 \t baseline err: 141.4637\n",
+ "Epoch 128/1000 \t Train Err: 35.6595 \t Test Err: 36.1782 \t baseline err: 141.4637\n",
+ "Epoch 129/1000 \t Train Err: 35.6861 \t Test Err: 36.1815 \t baseline err: 141.4637\n",
+ "Epoch 130/1000 \t Train Err: 35.5056 \t Test Err: 36.4077 \t baseline err: 141.4637\n",
+ "Epoch 131/1000 \t Train Err: 35.5161 \t Test Err: 36.6100 \t baseline err: 141.4637\n",
+ "Epoch 132/1000 \t Train Err: 35.4981 \t Test Err: 36.1938 \t baseline err: 141.4637\n",
+ "Epoch 133/1000 \t Train Err: 35.4577 \t Test Err: 36.2723 \t baseline err: 141.4637\n",
+ "Epoch 134/1000 \t Train Err: 35.5685 \t Test Err: 36.0326 \t baseline err: 141.4637\n",
+ "Epoch 135/1000 \t Train Err: 35.4790 \t Test Err: 36.0421 \t baseline err: 141.4637\n",
+ "Epoch 136/1000 \t Train Err: 35.5135 \t Test Err: 35.9383 \t baseline err: 141.4637\n",
+ "Epoch 137/1000 \t Train Err: 35.4027 \t Test Err: 35.5105 \t baseline err: 141.4637\n",
+ "Epoch 138/1000 \t Train Err: 35.4036 \t Test Err: 35.1535 \t baseline err: 141.4637\n",
+ "Epoch 139/1000 \t Train Err: 35.5047 \t Test Err: 35.7687 \t baseline err: 141.4637\n",
+ "Epoch 140/1000 \t Train Err: 35.3858 \t Test Err: 35.5093 \t baseline err: 141.4637\n",
+ "Epoch 141/1000 \t Train Err: 35.3629 \t Test Err: 35.1772 \t baseline err: 141.4637\n",
+ "Epoch 142/1000 \t Train Err: 35.3714 \t Test Err: 35.0682 \t baseline err: 141.4637\n",
+ "Epoch 143/1000 \t Train Err: 35.3976 \t Test Err: 35.6668 \t baseline err: 141.4637\n",
+ "Epoch 144/1000 \t Train Err: 35.4523 \t Test Err: 35.2633 \t baseline err: 141.4637\n",
+ "Epoch 145/1000 \t Train Err: 35.3763 \t Test Err: 35.5053 \t baseline err: 141.4637\n",
+ "Epoch 146/1000 \t Train Err: 35.3906 \t Test Err: 35.1866 \t baseline err: 141.4637\n",
+ "Epoch 147/1000 \t Train Err: 35.4717 \t Test Err: 35.2663 \t baseline err: 141.4637\n",
+ "Epoch 148/1000 \t Train Err: 49.5998 \t Test Err: 54.9490 \t baseline err: 141.4637\n",
+ "Epoch 149/1000 \t Train Err: 56.0901 \t Test Err: 55.1923 \t baseline err: 141.4637\n",
+ "Epoch 150/1000 \t Train Err: 55.4467 \t Test Err: 54.8896 \t baseline err: 141.4637\n",
+ "Epoch 151/1000 \t Train Err: 55.4050 \t Test Err: 55.0999 \t baseline err: 141.4637\n",
+ "Epoch 152/1000 \t Train Err: 55.3775 \t Test Err: 55.0012 \t baseline err: 141.4637\n",
+ "Epoch 153/1000 \t Train Err: 55.3564 \t Test Err: 55.1625 \t baseline err: 141.4637\n",
+ "Epoch 154/1000 \t Train Err: 55.4071 \t Test Err: 55.5768 \t baseline err: 141.4637\n",
+ "Epoch 155/1000 \t Train Err: 55.1633 \t Test Err: 55.3753 \t baseline err: 141.4637\n",
+ "Epoch 156/1000 \t Train Err: 45.2929 \t Test Err: 37.6036 \t baseline err: 141.4637\n",
+ "Epoch 157/1000 \t Train Err: 36.1442 \t Test Err: 35.6718 \t baseline err: 141.4637\n",
+ "Epoch 158/1000 \t Train Err: 35.5764 \t Test Err: 35.4595 \t baseline err: 141.4637\n",
+ "Epoch 159/1000 \t Train Err: 35.3944 \t Test Err: 35.3251 \t baseline err: 141.4637\n"
]
- },
- {
- "data": {
- "image/png": "",
- "text/plain": [
- "<Figure size 1000x500 with 1 Axes>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
}
],
"source": [
+ "train_err = []\n",
+ "test_err = []\n",
+ "\n",
"for epoch in range(NEPOCHS):\n",
- " model.train() # set to training mode\n",
+ " model.train()\n",
" train_loss = 0\n",
- "\n",
" for batch_src, batch_labels, batch_padding_mask in train_loader:\n",
" optimizer.zero_grad()\n",
" output = model(batch_src, batch_padding_mask)\n",
@@ -556,118 +721,107 @@
" loss.backward()\n",
" optimizer.step()\n",
"\n",
- " # Evaluate performance\n",
- " model.eval()\n",
- " test_loss = 0\n",
- "\n",
- " with torch.no_grad():\n",
- " for batch_src, batch_labels, batch_padding_mask in test_loader:\n",
- " output = model(batch_src, batch_padding_mask)\n",
- " loss = criterion(output.squeeze(1), batch_labels)\n",
- " test_loss += loss.item()/len(test_loader)\n",
- "\n",
+ " test_loss = evaluate()\n",
+ " \n",
" test_err.append(test_loss)\n",
" train_err.append(train_loss)\n",
- " print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f} \\t baseline err: {train_baseline:.4f}\")\n",
- "\n",
- "plt.figure(figsize=(10, 5))\n",
- "plt.plot(test_err, label='Test', color='red')\n",
- "plt.plot(train_err, label='Train', color='blue')\n",
- "plt.title('Accuracy vs Epochs')\n",
- "plt.xlabel('Epochs'); plt.ylabel('Accuracy')\n",
- "plt.legend(); plt.grid()\n",
- "plt.show()"
+ " ax.plot(train_err, label='Train', color='blue')\n",
+ " ax.plot(test_err, label='Test', color='red')\n",
+ " ax.set_xlabel('Epochs')\n",
+ " ax.set_ylabel('MSE')\n",
+ " fig.canvas.draw()\n",
+ " print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f} \\t baseline err: {baseline_error:.4f}\")\n",
+ "\n",
+ " if epoch % 100 == 99:\n",
+ " torch.save(model.state_dict(), f\"model_weights_{epoch}.pth\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "v1hCiItHDWxJ"
- },
- "outputs": [],
+ "execution_count": 80,
+ "execution_state": "idle",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "55.06520214080811"
+ ]
+ },
+ "execution_count": 80,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "## Q: why is this not working so well?\n",
- "\n",
- "## maybe first try a simpler problem: just give it points for distinguishing between distance 1 or not"
+ "evaluate()"
]
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": null,
+ "execution_state": "running",
"metadata": {
"id": "LoGEmM5lH7_A"
},
"outputs": [],
"source": [
"batch_src, batch_labels, batch_padding_mask = next(iter(train_loader))\n",
- "output = model(batch_src, batch_padding_mask)"
+ "output = model(batch_src, batch_padding_mask)\n",
+ "batch_src[0], batch_labels[0], output[0]"
]
},
{
"cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "hO8AhX3G7vF8",
- "outputId": "8f4a3ca6-db47-434d-95a4-4631bc73de62"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "0 \t nan\n",
- "0 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "0 \t nan\n",
- "1 \t nan\n",
- "0 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "0 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "0 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n",
- "1 \t nan\n"
- ]
- }
- ],
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
"source": [
- "for x,y in zip(batch_labels.tolist(), output.squeeze(1).tolist()):\n",
- " print(f\"{int(x)} \\t {y:.1f}\")"
+ "plt.hist(output.detach().cpu().numpy().flatten(),bins=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "id": "dRdUGbFmkPtK"
- },
+ "execution_state": "running",
+ "metadata": {},
"outputs": [],
"source": [
- "batch_src[2]"
+ "plt.hist(label.detach().cpu().numpy().flatten(),bins=32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.scatter(batch_labels.detach().cpu().numpy().flatten(),output.detach().cpu().numpy().flatten())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_src2, batch_labels2, batch_padding_mask2 = next(iter(test_loader))\n",
+ "output2 = model(batch_src2, batch_padding_mask2)\n",
+ "loss = criterion(output2.squeeze(1), batch_labels2)\n",
+ "batch_src2[0], batch_labels2[0], output2[0], loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.scatter(batch_labels2.detach().cpu().numpy().flatten(),output2.detach().cpu().numpy().flatten())"
]
},
{
@@ -680,6 +834,91 @@
]
},
{
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_TUNE_EPOCHS = 100  # number of fine-tuning epochs\n",
+ "TUNE_BSZ = 1024  # batch size for both fine-tuning loaders\n",
+ "TUNE_LR = 0.003  # learning rate passed to the fine-tuning Adam optimizer\n",
+ "TUNE_WD = 0.002  # weight decay passed to the fine-tuning Adam optimizer\n",
+ "\n",
+ "with open(\"tune_data.pkl\", \"rb\") as f:\n",
+ "    pickled_tune_stuff = pickle.load(f)  # NOTE: unpickling executes arbitrary code; only load files you generated yourself\n",
+ "\n",
+ "tune_data = pickled_tune_stuff[\"data\"].to(device)\n",
+ "tune_label = pickled_tune_stuff[\"labels\"].to(device)\n",
+ "tune_padding_mask = (tune_data == PAD_TOKEN)  # comparison already yields a bool tensor on tune_data's device\n",
+ "tune_dataset = TensorDataset(tune_data, tune_label, tune_padding_mask)\n",
+ "tune_train_dataset, tune_test_dataset = torch.utils.data.random_split(tune_dataset, [.8, .2])  # 80/20 split\n",
+ "tune_train_loader = DataLoader(tune_train_dataset, batch_size=TUNE_BSZ, shuffle=True)\n",
+ "tune_test_loader = DataLoader(tune_test_dataset, batch_size=TUNE_BSZ, shuffle=False)  # fixed order: the per-batch-averaged test error stays reproducible\n",
+ "\n",
+ "tune_criterion = nn.MSELoss()\n",
+ "tune_optimizer = torch.optim.Adam(model.parameters(), lr=TUNE_LR, weight_decay=TUNE_WD)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This has to be in a separate cell for some weird event loop reasons\n",
+ "%matplotlib widget\n",
+ "fig,ax = plt.subplots()\n",
+ "fig.suptitle('MSE vs Epochs')\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "execution_state": "running",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ax.clear()  # reset the axes created in the figure cell before starting a new run\n",
+ "# Fine-tune `model` on the tune split, recording mean train/test MSE per epoch.\n",
+ "tune_train_err = []\n",
+ "tune_test_err = []\n",
+ "\n",
+ "for epoch in range(N_TUNE_EPOCHS):\n",
+ "    model.train()\n",
+ "    tune_train_loss = 0\n",
+ "    for batch_src, batch_labels, batch_padding_mask in tune_train_loader:\n",
+ "        tune_optimizer.zero_grad()  # use the fine-tuning optimizer (TUNE_LR/TUNE_WD), not the earlier `optimizer`\n",
+ "        output = model(batch_src, batch_padding_mask)\n",
+ "        loss = tune_criterion(output.squeeze(1), batch_labels)\n",
+ "        tune_train_loss += loss.item()/len(tune_train_loader)  # accumulates the mean of per-batch losses\n",
+ "        loss.backward()\n",
+ "        tune_optimizer.step()\n",
+ "\n",
+ "    model.eval()\n",
+ "    tune_test_loss = 0\n",
+ "    with torch.no_grad():\n",
+ "        for batch_src, batch_labels, batch_padding_mask in tune_test_loader:\n",
+ "            output = model(batch_src, batch_padding_mask)\n",
+ "            loss = tune_criterion(output.squeeze(1), batch_labels)\n",
+ "            tune_test_loss += loss.item()/len(tune_test_loader)\n",
+ "    ax.clear()  # drop the previous epoch's curves so line artists do not pile up on the axes\n",
+ "    tune_test_err.append(tune_test_loss)\n",
+ "    tune_train_err.append(tune_train_loss)\n",
+ "    ax.plot(tune_train_err, label='Train', color='blue')\n",
+ "    ax.plot(tune_test_err, label='Test', color='red')\n",
+ "    ax.set_xlabel('Epochs')\n",
+ "    ax.set_ylabel('MSE')\n",
+ "    fig.canvas.draw()\n",
+ "    print(f\"Epoch {epoch + 1}/{N_TUNE_EPOCHS} \\t Train Err: {tune_train_loss:.4f} \\t Test Err: {tune_test_loss:.4f}\")\n",
+ "    # Checkpoint every 10th epoch; the previous `% 100 == 9` test fired only once in a 100-epoch run.\n",
+ "    if epoch % 10 == 9:\n",
+ "        torch.save(model.state_dict(), f\"tune_model_weights_{epoch}.pth\")"
+ ]
+ },
+ {
"cell_type": "markdown",
"metadata": {
"id": "JtTLXn4zC1z_"
@@ -709,7 +948,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.12.7"
}
},
"nbformat": 4,