diff options
author | SIPB | 2024-12-02 16:54:41 -0500 |
---|---|---|
committer | SIPB | 2024-12-02 16:54:41 -0500 |
commit | a24288e28c4b53fdd6467ed4eed626fa0586bf72 (patch) | |
tree | 47b3c91610c0dfff1ebe31611897a443ba85630b /insane-shortest-paths.ipynb | |
parent | 58e39027654343df3c140bf120194af1b3d5c159 (diff) |
Latest copy of blog post and insane TSP
Diffstat (limited to 'insane-shortest-paths.ipynb')
-rw-r--r-- | insane-shortest-paths.ipynb | 263 |
1 files changed, 263 insertions, 0 deletions
diff --git a/insane-shortest-paths.ipynb b/insane-shortest-paths.ipynb new file mode 100644 index 0000000..72846c2 --- /dev/null +++ b/insane-shortest-paths.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "execution_state": "idle", + "id": "86ce5f44-94f6-43b0-a0d1-091b8134ffb6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[set(), set(), {5, 6}, {4}, {3}, {2, 6}, {2, 5}]\n", + "[set(), {6}, set(), {4, 5, 6}, {3}, {3, 6}, {1, 3, 5}]\n", + "[set(), {4}, set(), {4, 5}, {1, 3}, {3, 6}, {5}]\n", + "[set(), {2, 6}, {1, 6}, {6}, set(), set(), {1, 2, 3}]\n", + "[set(), {3}, {3}, {1, 2, 5, 6}, {5}, {3, 4}, {3}]\n", + "[set(), {3, 6}, {4}, {1}, {2}, {6}, {1, 5}]\n", + "[set(), {2, 3}, {1, 3, 6}, {1, 2, 4}, {3}, set(), {2}]\n", + "[set(), {4}, set(), {4}, {1, 3, 5, 6}, {4}, {4}]\n", + "[set(), {3, 4, 5}, {6}, {1}, {1, 6}, {1}, {2, 4}]\n", + "[set(), {5, 6}, {6}, {6}, {5, 6}, {1, 4}, {1, 2, 3, 4}]\n" + ] + } + ], + "source": [ + "# -*- coding: utf-8 -*-\n", + "\"\"\"how-tsp-should-be.ipynb\n", + "\n", + "Automatically generated by Colab.\n", + "\n", + "Original file is located at\n", + " https://colab.research.google.com/drive/1InE1iW8ARzndPpvqH_9y22s81sOiHxPs\n", + "\"\"\"\n", + "\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "from math import sqrt\n", + "from collections import deque\n", + "import os\n", + "import random\n", + "import pickle\n", + "import ipdb\n", + "\n", + "# torch.manual_seed(30)\n", + "# random.seed(30)\n", + "torch.manual_seed(33)\n", + "random.seed(33)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "# assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\"\n", + "\n", + "NVTXS = 6\n", + "MAXDIST = NVTXS+1\n", + "AVGDEG = 2\n", + "SEQLEN = NVTXS + 1\n", + "HIDDENDIM = 4*NVTXS+2\n", + "\n", + "# 0: ANSFLAG\n", + "# 1:NVTXS+1 NBRS\n", + "# NVTXS+1: 2*NVTXS+1 REACH\n", + "# 2*NVTXS+1: 3*NVTXS+1 SELF\n", + "# -1 NOTANSFLAG\n", + "\n", + "START_REACH = NVTXS+1\n", + "START_OUT = 2*NVTXS+1\n", + "START_SELF = 3*NVTXS+1\n", + "SRC_FLAG_IDX = START_SELF\n", + "SOURCE = 1\n", + "TARGET = 2\n", + "ANS_FLAG_IDX = 0\n", + "NOTANS_FLAG_IDX = -1\n", + "\n", + "def print_everything(data):\n", + " print(\"NBRS\")\n", + " print(data[0, 1:, 1:1+NVTXS])\n", + " print(\"REACH\")\n", + " print(data[0, 1:, START_REACH:START_REACH+NVTXS])\n", + " print(\"ANSFLAG\")\n", + " print(data[0, :, 0])\n", + " print(\"MORE FLAGS\")\n", + " print(data[0, :, -1])\n", + " print(\"SELF\")\n", + " print(data[0, 1:, START_SELF:START_SELF+NVTXS])\n", + " print(\"OUT\")\n", + " print(data[0, 0, START_OUT:START_OUT+NVTXS])\n", + "\n", + "\n", + "def random_graph():\n", + " data = torch.zeros((SEQLEN, HIDDENDIM))\n", + "\n", + " for i in range(1,NVTXS+1):\n", + " data[i, START_SELF-1+i] = 1\n", + "\n", + " adj_list = [set() for _ in range(SEQLEN)]\n", + " indices = [random.randint(1, NVTXS) for _ in range(AVGDEG * NVTXS)]\n", + " for i in range(0, len(indices), 2):\n", + " u = indices[i]\n", + " v = indices[i + 1]\n", + " if u != v:\n", + " data[v,u] = 1\n", + " data[u,v] = 1\n", + " data[v,NVTXS+u] = 1\n", + " data[u,NVTXS+v] = 1\n", + " adj_list[u].add(v)\n", + " adj_list[v].add(u)\n", + "\n", + " data[0, ANS_FLAG_IDX] = 1\n", + " data[1:, NOTANS_FLAG_IDX] = 1\n", + "\n", + " # TODO: this is kind of a hack\n", + " data[0, START_REACH:START_REACH+NVTXS] = 1\n", + " return data, adj_list\n", + "\n", + "\"\"\"\n", + "input: G, represented as an adjacency list\n", + "output: distance from SOURCE to TARGET\n", + "\"\"\"\n", + "def SSSP(G):\n", + " dist = [MAXDIST for _ in G]\n", + " dist[SOURCE] = 0\n", + " frontier = deque()\n", + " frontier.append(SOURCE)\n", + " while len(frontier) > 0:\n", + " vtx = frontier.popleft()\n", + " for x in G[vtx]:\n", + " if dist[x] == MAXDIST:\n", + " dist[x] = 1 + dist[vtx]\n", + " frontier.append(x)\n", + " if x == TARGET:\n", + " return dist[TARGET]\n", + " return MAXDIST\n", + "\n", + "def mkbatch(size):\n", + " graphs1 = []\n", + " distance1 = []\n", + "\n", + " for i in range(size):\n", + " data, adj_list = random_graph()\n", + " dist = SSSP(adj_list)\n", + " graphs1.append(data)\n", + " distance1.append(dist)\n", + "\n", + " print(adj_list)\n", + "\n", + " data = torch.stack(graphs1)\n", + " labels = torch.tensor(distance1, dtype=torch.float16)\n", + " return data, labels\n", + "\n", + "\"\"\"\n", + "TODO: WRAP EVERYTHING in nn.Parameter(torch.zeros((1, HIDDENDIM)))\n", + "and then do my perturbing parameters experiment\n", + "\n", + "TODO:\n", + " USE activation magic to bring everything back to the 0/1 realm instead of possibly being 0/2 valued\n", + "\"\"\"\n", + "\n", + "class SillyTransformer(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.most_KQVs = []\n", + " for head in range(1,NVTXS+1):\n", + " Q = torch.zeros((2, HIDDENDIM))\n", + " Q[0, START_REACH-1+head] = 1000\n", + " Q[1, NOTANS_FLAG_IDX] = 1\n", + "\n", + " K = torch.zeros((2, HIDDENDIM))\n", + " K[0, head] = 1\n", + " K[1, ANS_FLAG_IDX] = 200\n", + "\n", + " V = torch.zeros((NVTXS,HIDDENDIM))\n", + " for i in range(NVTXS):\n", + " V[i, START_SELF+i] = 1\n", + "\n", + " self.most_KQVs.append((K, Q, V))\n", + "\n", + " self.weird_KQVs = []\n", + " for layer in range(NVTXS):\n", + " K = torch.zeros((3, HIDDENDIM))\n", + " K[0, NOTANS_FLAG_IDX] = -1000\n", + " K[0, SRC_FLAG_IDX] = +1100\n", + " K[1, NOTANS_FLAG_IDX] = -1000\n", + " K[1, NVTXS+TARGET] = +1100\n", + " K[1, ANS_FLAG_IDX] = -1100\n", + " K[2, ANS_FLAG_IDX] = 10\n", + "\n", + " Q = torch.zeros((3, HIDDENDIM))\n", + " Q[:, ANS_FLAG_IDX] = 1\n", + "\n", + " V = torch.zeros((NVTXS, HIDDENDIM))\n", + " V[layer, SRC_FLAG_IDX] = 1\n", + "\n", + " self.weird_KQVs.append((K, Q, V))\n", + "\n", + " def forward(self, src):\n", + " for layer in range(NVTXS):\n", + " allKQVs = [self.weird_KQVs[layer]] + self.most_KQVs\n", + " head_outputs = []\n", + " for (K, Q, V) in allKQVs:\n", + " ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))\n", + " qsrc = torch.matmul(src, Q.unsqueeze(0).transpose(-2, -1))\n", + " vsrc = torch.matmul(src, V.unsqueeze(0).transpose(-2, -1))\n", + "\n", + " scores = torch.matmul(qsrc, ksrc.transpose(-2, -1))\n", + " attention_weights = torch.softmax(scores, dim=-1)\n", + " head_output = torch.matmul(attention_weights, vsrc)\n", + " head_outputs.append(head_output)\n", + "\n", + " new_reaches = sum(head_outputs[1:])\n", + " BSZ = new_reaches.shape[0]\n", + "\n", + " nodelta_nbrs = torch.zeros((BSZ, SEQLEN, NVTXS+1))\n", + " morepadlol = torch.zeros((BSZ, SEQLEN, 1+NVTXS))\n", + "\n", + " DIFF = torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n", + " src += torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n", + " src[:, :, START_REACH:START_REACH+NVTXS] = 2*torch.sigmoid(src[:,:, START_REACH:START_REACH+NVTXS]*1000)-1\n", + "\n", + " # print(\"SRC\")\n", + " # print_everything(src)\n", + "\n", + " canreach = src[:,0,START_OUT:START_OUT+NVTXS]\n", + " # __import__('ipdb').set_trace()\n", + " final_output = 1+torch.sum(1-canreach,dim=1)\n", + " return final_output\n", + "\n", + "model = SillyTransformer()\n", + "model.to(device)\n", + "\n", + "data, labels = mkbatch(10)\n", + "assert torch.all(model(data) == labels)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |