aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2024-07-26 23:03:29 -0500
committerAnthony Wang2024-07-26 23:03:29 -0500
commit088075b13330be9730e339d9eaf01e8ab468d6c3 (patch)
treeb90bbca4b68c8bbe8ca9c2dc69a6bab2df5816e9
Initial commit
-rw-r--r--.gitmodules3
m---------ImageBind0
-rw-r--r--README.md17
-rw-r--r--client.py5
-rw-r--r--model.py38
-rw-r--r--requirements.txt4
-rw-r--r--server.py147
7 files changed, 214 insertions, 0 deletions
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..43dfa86
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "ImageBind"]
+ path = ImageBind
+ url = https://github.com/facebookresearch/ImageBind.git
diff --git a/ImageBind b/ImageBind
new file mode 160000
+Subproject 3fcf5c9039de97f6ff5528ee4a9dce903c5979b
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..443cc5b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,17 @@
+# search
+
+Semantic file search using ImageBind and sqlite-vec
+
+## Installation
+
+```
+pip install -r requirements.txt
+```
+
+If you don't need GPU support, install the CPU-only build of PyTorch first by running `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu` before the command above. If you get the error `ModuleNotFoundError: No module named 'torchvision.transforms.functional_tensor'` at runtime, edit the import statement shown in the traceback and change `functional_tensor` to `functional`.
+
+## Usage
+
+This program uses a client-server architecture: the server watches directories with inotify and keeps the model loaded in memory. Loading the model takes around 15 seconds, so a long-running server avoids paying that cost on every query.
+
+Run `python server.py DIRS_TO_INDEX` to start the server. Then, run `python client.py SEARCH_TEXT NUM_RESULTS` to get a list of the most similar files.
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..2d7ea97
--- /dev/null
+++ b/client.py
@@ -0,0 +1,5 @@
"""Command-line client for the semantic search server.

Usage: python client.py SEARCH_TEXT NUM_RESULTS
"""
import sys
import xmlrpc.client


def main(argv):
    """Send one search request and print one matching path per line.

    argv is the full argument vector (program name first). Returns a process
    exit status: 0 on success, 2 when the arguments are missing or invalid.
    """
    if len(argv) != 3:
        print(f"usage: {argv[0]} SEARCH_TEXT NUM_RESULTS", file=sys.stderr)
        return 2

    # Send the result count as an integer rather than the raw argv string so
    # the server receives a typed value for its SQL parameter.
    try:
        limit = int(argv[2])
    except ValueError:
        print(f"NUM_RESULTS must be an integer, got {argv[2]!r}", file=sys.stderr)
        return 2

    proxy = xmlrpc.client.ServerProxy("http://localhost:8000")
    print("\n".join(proxy.search(argv[1], limit)))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d2e765d
--- /dev/null
+++ b/model.py
@@ -0,0 +1,38 @@
+import torch
+from imagebind import data
+from imagebind.models import imagebind_model
+from imagebind.models.imagebind_model import ModalityType
+
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

print("Loading model")
# Single module-level model instance shared by every embed_* helper below.
model = imagebind_model.imagebind_huge(pretrained=True)
# Switch to evaluation mode before moving the model to the selected device.
model.eval()
model.to(device)
+
+
def embed(mode, data):
    """Run the model on one pre-processed input and return its embedding.

    mode is the modality key used both to feed the model and to select its
    output; data is the already-transformed input batch for that modality.
    Only the first item of the batch is returned.

    The result is moved to the CPU because callers convert it with
    `.numpy()`, which raises for a tensor that still lives on a CUDA device.
    On a CPU-only run the `.cpu()` call is a no-op.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model({mode: data})
    return outputs[mode][0].cpu()
+
+
def embed_text(text):
    """Return the embedding of a single text string."""
    # The loader takes a list, so wrap the single string in a one-item batch.
    batch = data.load_and_transform_text([text], device)
    return embed(ModalityType.TEXT, batch)
+
+
def embed_audio(audio_path):
    """Return the embedding of the audio file at audio_path."""
    # The loader takes a list of paths, so pass a one-item batch.
    batch = data.load_and_transform_audio_data([audio_path], device)
    return embed(ModalityType.AUDIO, batch)
+
+
def embed_image(image_path):
    """Return the embedding of the image file at image_path."""
    # The loader takes a list of paths, so pass a one-item batch.
    batch = data.load_and_transform_vision_data([image_path], device)
    return embed(ModalityType.VISION, batch)
+
+
def embed_video(video_path):
    """Return the embedding of the video file at video_path.

    Video is fed to the model under the VISION modality, the same key used
    for still images.
    """
    batch = data.load_and_transform_video_data([video_path], device)
    return embed(ModalityType.VISION, batch)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7a559de
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+file:ImageBind
+pillow-avif-plugin==1.4.6
+watchdog==4.0.1
+sqlite-vec==0.0.1a37
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..4c3046e
--- /dev/null
+++ b/server.py
@@ -0,0 +1,147 @@
+import mimetypes
+import os
+import pathlib
+import sqlite3
+import sys
+from xmlrpc.server import SimpleXMLRPCServer
+import pillow_avif
+import sqlite_vec
+from watchdog.observers import Observer
+from watchdog.events import FileSystemEventHandler
+import model
+
+
print("Connecting to DB")
# check_same_thread=False lets the filesystem-watcher thread and the RPC
# thread share this connection. sqlite3 does not serialize access for us in
# that mode, so concurrent writes are a possible source of corruption.
con = sqlite3.connect("index.db", check_same_thread=False)
# Extension loading is only needed to load sqlite-vec; switch it off again
# straight afterwards so nothing else can load native code through SQL.
con.enable_load_extension(True)
sqlite_vec.load(con)
con.enable_load_extension(False)
cur = con.cursor()
# idx: one row per indexed file or directory, keyed by inode number, with
# the inode of its parent directory, its mtime and its path.
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: one 1024-float embedding per indexed file, keyed by the same inode.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024])"
)
con.commit()
+
+
def get_parent(path):
    """Return the inode number recorded as the parent of path in the index.

    Watched directories are indexed as children of the synthetic root "/",
    so their parent is the inode of "/". This must not be 0: index() treats
    a parent of 0 as "this is the synthetic root", and would then handle the
    watched directory as the root and unindex its real contents. Every other
    path's parent is the inode of its containing directory.

    Raises FileNotFoundError if the containing directory no longer exists.
    """
    if path in watchdirs:
        return os.stat("/").st_ino
    return os.stat(pathlib.Path(path).parent).st_ino
+
+
class EventHandler(FileSystemEventHandler):
    """Keep the index in sync with filesystem events from the observer."""

    def on_created(self, event):
        """Index a newly created (or modified) path."""
        try:
            index(event.src_path, get_parent(event.src_path))
        except FileNotFoundError:
            # Short-lived files can vanish before the event is handled.
            print("Vanished before indexing", event.src_path)

    def on_modified(self, event):
        """Re-index a modified path; index() skips unchanged files."""
        self.on_created(event)

    def on_deleted(self, event):
        """Remove a deleted path, and anything indexed beneath it."""
        # The file is already gone, so it cannot be stat()ed for its inode.
        # Recover the id from the path stored at indexing time instead.
        rows = cur.execute(
            "SELECT id FROM idx WHERE path = ?", (event.src_path,)
        ).fetchall()
        for (inode,) in rows:
            unindex(inode)

    def on_moved(self, event):
        """Record the new path and parent of a moved entry."""
        # inode doesn't change after move, so the existing row (and its
        # embedding) can be kept and only the metadata rewritten.
        try:
            s = os.stat(event.dest_path)
            parent = get_parent(event.dest_path)
        except FileNotFoundError:
            print("Vanished before indexing", event.dest_path)
            return
        cur.execute(
            "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
            (s.st_ino, parent, s.st_mtime, event.dest_path),
        )
        con.commit()
+
+
def _embed_file(path):
    """Return an embedding for the file at path, or None if it is skipped.

    Files with no guessable MIME type are treated as plain text when smaller
    than 64 KiB; audio and image files are always embedded; videos only when
    smaller than 32 MiB. Everything else yields None.
    """
    mime = mimetypes.guess_type(path)[0]
    if mime is None:
        if os.path.getsize(path) >= 2**16:
            return None
        try:
            with open(path) as f:
                return model.embed_text(f.read())
        except Exception:
            # Best effort: undecodable or otherwise unusable content.
            print("Not plain text, skipping")
            return None
    if mime.startswith("audio"):
        return model.embed_audio(path)
    if mime.startswith("image"):
        return model.embed_image(path)
    if mime.startswith("video") and os.path.getsize(path) < 2**25:
        return model.embed_video(path)
    return None


def _index_dir(path, parent, inode):
    """Sync the index with the children of the directory at path.

    inode is the directory's own inode number. When parent is 0 the
    directory is the synthetic root and its children are the watched dirs.
    """
    names = os.listdir(path) if parent else watchdirs

    # Stat each child once; skip entries that cannot be stat()ed (dangling
    # symlinks, files removed since the listing) instead of aborting.
    live = []
    for name in names:
        child_path = os.path.join(path, name)
        try:
            live.append((name, child_path, os.stat(child_path).st_ino))
        except OSError:
            print("Cannot stat, skipping", child_path)

    # Find and unindex dead children
    live_ids = {child_inode for _, _, child_inode in live}
    rows = cur.execute("SELECT id FROM idx WHERE parent = ?", (inode,)).fetchall()
    for (db_child_id,) in rows:
        if db_child_id not in live_ids:
            unindex(db_child_id)

    # Index live children
    for name, child_path, _ in live:
        if not parent:
            # Top level: start watching each indexed directory.
            observer.schedule(event_handler, name, recursive=True)
        index(child_path, inode)


def index(path, parent):
    """Add or refresh path (recursively, for directories) in the index.

    parent is the inode number of the containing directory, or 0 when path
    is the synthetic root whose children are the watched directories. Files
    whose stored mtime and parent are unchanged are left alone. Files that
    yield no embedding are removed from the index.
    """
    print("Indexing", path, parent)
    try:
        s = os.stat(path)
    except FileNotFoundError:
        print("Vanished, skipping", path)
        return

    if os.path.isfile(path):
        row = cur.execute(
            "SELECT time, parent FROM idx WHERE id = ?", (s.st_ino,)
        ).fetchone()
        if row == (s.st_mtime, parent):
            # Already in DB, unmodified
            return

        emb = _embed_file(path)
        if emb is None:
            # Might be in index but no longer valid
            unindex(s.st_ino)
            return

        # Delete-then-insert rather than INSERT OR REPLACE: an explicit
        # delete does not rely on the vec0 virtual table's conflict handling
        # (NOTE(review): confirm OR REPLACE support for the pinned version).
        # .cpu() is required before .numpy() when the model runs on a GPU.
        cur.execute("DELETE FROM emb WHERE id = ?", (s.st_ino,))
        cur.execute(
            "INSERT INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy())
        )
    elif os.path.isdir(path):
        _index_dir(path, parent, s.st_ino)
    else:
        # Sockets, FIFOs, devices: nothing to embed and nothing to list.
        print("Not a file or directory, skipping", path)
        return

    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()
+
+
def unindex(id):
    """Remove the entry with inode number id, and all its descendants.

    Descendants are found through the parent column of idx. The subtree is
    collected iteratively, with a visited set so that a parent cycle in the
    table cannot loop forever, then deleted from both idx and emb in a
    single transaction.
    """
    doomed = []
    seen = set()
    pending = [id]
    while pending:
        node = pending.pop()
        if node in seen:
            continue
        seen.add(node)
        doomed.append(node)
        print("Unindexing", node)
        rows = cur.execute(
            "SELECT id FROM idx WHERE parent = ?", (node,)
        ).fetchall()
        # Reversed so children are visited in the order the query returned.
        pending.extend(row[0] for row in reversed(rows))

    params = [(node,) for node in doomed]
    cur.executemany("DELETE FROM idx WHERE id = ?", params)
    cur.executemany("DELETE FROM emb WHERE id = ?", params)
    con.commit()
+
+
def search(text, limit):
    """Return the paths of the indexed files most similar to text.

    text is the query string; limit is the maximum number of results and may
    arrive as an int or a numeric string (the value comes from an XML-RPC
    client). Paths are ordered from nearest to farthest. A non-positive
    limit yields an empty list.

    Raises ValueError if limit is not an integer value.
    """
    k = int(limit)
    if k <= 0:
        return []

    # .cpu() is required before .numpy() when the model runs on a GPU.
    query = model.embed_text(text).cpu().numpy()
    # con.execute() uses a fresh cursor, so this result set cannot be
    # overwritten by the watcher thread running statements on the shared one.
    rows = con.execute(
        "SELECT idx.path FROM emb LEFT JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
        (query, k),
    ).fetchall()
    # The LEFT JOIN yields NULL for an embedding without an idx row; drop
    # those, since None cannot be marshalled in the XML-RPC response.
    return [path for (path,) in rows if path is not None]
+
+
print("Indexing files")
# Make the watched dirs absolute: index() joins each one onto "/", which
# would silently turn a relative argument into the wrong path.
watchdirs = {os.path.abspath(d) for d in sys.argv[1:]}
observer = Observer()
event_handler = EventHandler()
# Pretend that / is the parent of all indexed dirs
index("/", 0)
# Watches were scheduled during the initial indexing pass; start delivering
# events only once that pass has finished.
observer.start()


print("Starting RPC server")
server = SimpleXMLRPCServer(("localhost", 8000))
server.register_function(search)
try:
    server.serve_forever()
except KeyboardInterrupt:
    print("Shutting down")
finally:
    # Stop the watcher thread before exiting.
    observer.stop()
    observer.join()