Diffstat (limited to 'insane-shortest-paths.ipynb')
-rw-r--r--  insane-shortest-paths.ipynb  263
1 file changed, 263 insertions, 0 deletions
diff --git a/insane-shortest-paths.ipynb b/insane-shortest-paths.ipynb
new file mode 100644
index 0000000..72846c2
--- /dev/null
+++ b/insane-shortest-paths.ipynb
@@ -0,0 +1,263 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "execution_state": "idle",
+ "id": "86ce5f44-94f6-43b0-a0d1-091b8134ffb6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[set(), set(), {5, 6}, {4}, {3}, {2, 6}, {2, 5}]\n",
+ "[set(), {6}, set(), {4, 5, 6}, {3}, {3, 6}, {1, 3, 5}]\n",
+ "[set(), {4}, set(), {4, 5}, {1, 3}, {3, 6}, {5}]\n",
+ "[set(), {2, 6}, {1, 6}, {6}, set(), set(), {1, 2, 3}]\n",
+ "[set(), {3}, {3}, {1, 2, 5, 6}, {5}, {3, 4}, {3}]\n",
+ "[set(), {3, 6}, {4}, {1}, {2}, {6}, {1, 5}]\n",
+ "[set(), {2, 3}, {1, 3, 6}, {1, 2, 4}, {3}, set(), {2}]\n",
+ "[set(), {4}, set(), {4}, {1, 3, 5, 6}, {4}, {4}]\n",
+ "[set(), {3, 4, 5}, {6}, {1}, {1, 6}, {1}, {2, 4}]\n",
+ "[set(), {5, 6}, {6}, {6}, {5, 6}, {1, 4}, {1, 2, 3, 4}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# -*- coding: utf-8 -*-\n",
+ "\"\"\"how-tsp-should-be.ipynb\n",
+ "\n",
+ "Automatically generated by Colab.\n",
+ "\n",
+ "Original file is located at\n",
+ " https://colab.research.google.com/drive/1InE1iW8ARzndPpvqH_9y22s81sOiHxPs\n",
+ "\"\"\"\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import matplotlib as mpl\n",
+ "import matplotlib.pyplot as plt\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "from math import sqrt\n",
+ "from collections import deque\n",
+ "import os\n",
+ "import random\n",
+ "import pickle\n",
+ "import ipdb\n",
+ "\n",
+ "# torch.manual_seed(30)\n",
+ "# random.seed(30)\n",
+ "torch.manual_seed(33)\n",
+ "random.seed(33)\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "# assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\"\n",
+ "\n",
+ "NVTXS = 6\n",
+ "MAXDIST = NVTXS+1\n",
+ "AVGDEG = 2\n",
+ "SEQLEN = NVTXS + 1\n",
+ "HIDDENDIM = 4*NVTXS+2\n",
+ "\n",
+ "# 0: ANSFLAG\n",
+ "# 1:NVTXS+1 NBRS\n",
+ "# NVTXS+1: 2*NVTXS+1 REACH\n",
+ "# 2*NVTXS+1: 3*NVTXS+1 SELF\n",
+ "# -1 NOTANSFLAG\n",
+ "\n",
+ "START_REACH = NVTXS+1\n",
+ "START_OUT = 2*NVTXS+1\n",
+ "START_SELF = 3*NVTXS+1\n",
+ "SRC_FLAG_IDX = START_SELF\n",
+ "SOURCE = 1\n",
+ "TARGET = 2\n",
+ "ANS_FLAG_IDX = 0\n",
+ "NOTANS_FLAG_IDX = -1\n",
+ "\n",
+ "def print_everything(data):\n",
+ " print(\"NBRS\")\n",
+ " print(data[0, 1:, 1:1+NVTXS])\n",
+ " print(\"REACH\")\n",
+ " print(data[0, 1:, START_REACH:START_REACH+NVTXS])\n",
+ " print(\"ANSFLAG\")\n",
+ " print(data[0, :, 0])\n",
+ " print(\"MORE FLAGS\")\n",
+ " print(data[0, :, -1])\n",
+ " print(\"SELF\")\n",
+ " print(data[0, 1:, START_SELF:START_SELF+NVTXS])\n",
+ " print(\"OUT\")\n",
+ " print(data[0, 0, START_OUT:START_OUT+NVTXS])\n",
+ "\n",
+ "\n",
+ "def random_graph():\n",
+ " data = torch.zeros((SEQLEN, HIDDENDIM))\n",
+ "\n",
+ " for i in range(1,NVTXS+1):\n",
+ " data[i, START_SELF-1+i] = 1\n",
+ "\n",
+ " adj_list = [set() for _ in range(SEQLEN)]\n",
+ " indices = [random.randint(1, NVTXS) for _ in range(AVGDEG * NVTXS)]\n",
+ " for i in range(0, len(indices), 2):\n",
+ " u = indices[i]\n",
+ " v = indices[i + 1]\n",
+ " if u != v:\n",
+ " data[v,u] = 1\n",
+ " data[u,v] = 1\n",
+ " data[v,NVTXS+u] = 1\n",
+ " data[u,NVTXS+v] = 1\n",
+ " adj_list[u].add(v)\n",
+ " adj_list[v].add(u)\n",
+ "\n",
+ " data[0, ANS_FLAG_IDX] = 1\n",
+ " data[1:, NOTANS_FLAG_IDX] = 1\n",
+ "\n",
+ " # TODO: this is kind of a hack\n",
+ " data[0, START_REACH:START_REACH+NVTXS] = 1\n",
+ " return data, adj_list\n",
+ "\n",
+ "\"\"\"\n",
+ "input: G, represented as an adjacency list\n",
+ "output: distance from SOURCE to TARGET\n",
+ "\"\"\"\n",
+ "def SSSP(G):\n",
+ " dist = [MAXDIST for _ in G]\n",
+ " dist[SOURCE] = 0\n",
+ " frontier = deque()\n",
+ " frontier.append(SOURCE)\n",
+ " while len(frontier) > 0:\n",
+ " vtx = frontier.popleft()\n",
+ " for x in G[vtx]:\n",
+ " if dist[x] == MAXDIST:\n",
+ " dist[x] = 1 + dist[vtx]\n",
+ " frontier.append(x)\n",
+ " if x == TARGET:\n",
+ " return dist[TARGET]\n",
+ " return MAXDIST\n",
+ "\n",
+ "def mkbatch(size):\n",
+ " graphs1 = []\n",
+ " distance1 = []\n",
+ "\n",
+ " for i in range(size):\n",
+ " data, adj_list = random_graph()\n",
+ " dist = SSSP(adj_list)\n",
+ " graphs1.append(data)\n",
+ " distance1.append(dist)\n",
+ "\n",
+ " print(adj_list)\n",
+ "\n",
+ " data = torch.stack(graphs1)\n",
+ " labels = torch.tensor(distance1, dtype=torch.float16)\n",
+ " return data, labels\n",
+ "\n",
+ "\"\"\"\n",
+ "TODO: WRAP EVERYTHING in nn.Parameter(torch.zeros((1, HIDDENDIM)))\n",
+ "and then do my perturbing parameters experiment\n",
+ "\n",
+ "TODO:\n",
+ " USE activation magic to bring everything back to the 0/1 realm instead of possibly being 0/2 valued\n",
+ "\"\"\"\n",
+ "\n",
+ "class SillyTransformer(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.most_KQVs = []\n",
+ " for head in range(1,NVTXS+1):\n",
+ " Q = torch.zeros((2, HIDDENDIM))\n",
+ " Q[0, START_REACH-1+head] = 1000\n",
+ " Q[1, NOTANS_FLAG_IDX] = 1\n",
+ "\n",
+ " K = torch.zeros((2, HIDDENDIM))\n",
+ " K[0, head] = 1\n",
+ " K[1, ANS_FLAG_IDX] = 200\n",
+ "\n",
+ " V = torch.zeros((NVTXS,HIDDENDIM))\n",
+ " for i in range(NVTXS):\n",
+ " V[i, START_SELF+i] = 1\n",
+ "\n",
+ " self.most_KQVs.append((K, Q, V))\n",
+ "\n",
+ " self.weird_KQVs = []\n",
+ " for layer in range(NVTXS):\n",
+ " K = torch.zeros((3, HIDDENDIM))\n",
+ " K[0, NOTANS_FLAG_IDX] = -1000\n",
+ " K[0, SRC_FLAG_IDX] = +1100\n",
+ " K[1, NOTANS_FLAG_IDX] = -1000\n",
+ " K[1, NVTXS+TARGET] = +1100\n",
+ " K[1, ANS_FLAG_IDX] = -1100\n",
+ " K[2, ANS_FLAG_IDX] = 10\n",
+ "\n",
+ " Q = torch.zeros((3, HIDDENDIM))\n",
+ " Q[:, ANS_FLAG_IDX] = 1\n",
+ "\n",
+ " V = torch.zeros((NVTXS, HIDDENDIM))\n",
+ " V[layer, SRC_FLAG_IDX] = 1\n",
+ "\n",
+ " self.weird_KQVs.append((K, Q, V))\n",
+ "\n",
+ " def forward(self, src):\n",
+ " for layer in range(NVTXS):\n",
+ " allKQVs = [self.weird_KQVs[layer]] + self.most_KQVs\n",
+ " head_outputs = []\n",
+ " for (K, Q, V) in allKQVs:\n",
+ " ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))\n",
+ " qsrc = torch.matmul(src, Q.unsqueeze(0).transpose(-2, -1))\n",
+ " vsrc = torch.matmul(src, V.unsqueeze(0).transpose(-2, -1))\n",
+ "\n",
+ " scores = torch.matmul(qsrc, ksrc.transpose(-2, -1))\n",
+ " attention_weights = torch.softmax(scores, dim=-1)\n",
+ " head_output = torch.matmul(attention_weights, vsrc)\n",
+ " head_outputs.append(head_output)\n",
+ "\n",
+ " new_reaches = sum(head_outputs[1:])\n",
+ " BSZ = new_reaches.shape[0]\n",
+ "\n",
+ " nodelta_nbrs = torch.zeros((BSZ, SEQLEN, NVTXS+1))\n",
+ " morepadlol = torch.zeros((BSZ, SEQLEN, 1+NVTXS))\n",
+ "\n",
+ " DIFF = torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n",
+ " src += torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n",
+ " src[:, :, START_REACH:START_REACH+NVTXS] = 2*torch.sigmoid(src[:,:, START_REACH:START_REACH+NVTXS]*1000)-1\n",
+ "\n",
+ " # print(\"SRC\")\n",
+ " # print_everything(src)\n",
+ "\n",
+ " canreach = src[:,0,START_OUT:START_OUT+NVTXS]\n",
+ " # __import__('ipdb').set_trace()\n",
+ " final_output = 1+torch.sum(1-canreach,dim=1)\n",
+ " return final_output\n",
+ "\n",
+ "model = SillyTransformer()\n",
+ "model.to(device)\n",
+ "\n",
+ "data, labels = mkbatch(10)\n",
+ "assert torch.all(model(data) == labels)\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}