import mimetypes
import os
import pathlib
import socketserver
import sqlite3
import sys
import threading
import traceback
from xmlrpc.server import SimpleXMLRPCDispatcher, SimpleXMLRPCRequestHandler

import pillow_avif  # noqa: F401 -- imported only for its side effect (presumably AVIF support in Pillow)
import sqlite_vec
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileOpenedEvent

import model

# Videos at or above this size (32 MiB) are not embedded.
MAX_VIDEO_BYTES = 2**25

print("Connecting to DB")
# The connection is shared by watchdog's observer thread (event handlers) and
# the main thread (initial scan, then RPC searches); every use is serialized
# through `lock` below.
con = sqlite3.connect("index.db", check_same_thread=False)
con.enable_load_extension(True)
sqlite_vec.load(con)
cur = con.cursor()
# idx: one row per indexed file or directory. `id` is the entry's inode number
# and `parent` the inode of its directory (0 for the "/" pseudo-root).
# NOTE(review): inode numbers are only unique within one filesystem, so
# watched directories on different filesystems could collide.
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: one embedding vector per indexed file, keyed by the same inode id.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024] distance_metric=cosine)"
)
con.commit()

lock = threading.Lock()


def get_parent(path):
    """Return the inode that `path` is filed under in idx.

    Watched root directories hang off "/" (the pseudo-root of the initial
    scan); everything else uses its real containing directory.
    """
    parent = "/" if path in watchdirs else pathlib.Path(path).parent
    return os.stat(parent).st_ino


class EventHandler(FileSystemEventHandler):
    """Keep idx/emb in sync with filesystem events reported by watchdog."""

    def dispatch(self, event):
        # "Opened" notifications cannot change the index; everything else is
        # handled under the shared DB lock.
        if not isinstance(event, FileOpenedEvent):
            with lock:
                print(event)
                super().dispatch(event)

    def on_created(self, event):
        index(event.src_path, get_parent(event.src_path))

    def on_modified(self, event):
        # Re-index files only; a modified directory is not rescanned here.
        if not event.is_directory:
            self.on_created(event)

    def on_deleted(self, event):
        # Drop every row stored at this path (normally exactly one);
        # unindex() also removes everything filed beneath a directory.
        rows = cur.execute(
            "SELECT id FROM idx WHERE path = ?", (event.src_path,)
        ).fetchall()
        for (row_id,) in rows:
            unindex(row_id)

    def on_moved(self, event):
        """Follow a rename/move without re-embedding unchanged files."""
        src, dest = event.src_path, event.dest_path
        try:
            # inode doesn't change after move
            dest_ino = os.stat(dest).st_ino
        except OSError:
            dest_ino = None  # already gone again; a later event covers it
        if dest_ino is not None:
            # Moving onto an existing path replaces that entry (e.g. an editor
            # renaming its temp file over the original): drop the old rows,
            # but never the moved entry itself.
            replaced = cur.execute(
                "SELECT id FROM idx WHERE path = ? AND id != ?", (dest, dest_ino)
            ).fetchall()
            for (row_id,) in replaced:
                unindex(row_id)
        # Re-point the moved entry ...
        cur.execute("UPDATE idx SET path = ? WHERE path = ?", (dest, src))
        # ... and everything stored beneath it. Anchoring on "src/" leaves
        # siblings sharing a textual prefix ("/a/b" vs "/a/bc") and unrelated
        # paths alone. SQLite's substr() is 1-based.
        src_prefix = os.path.join(src, "")
        cur.execute(
            "UPDATE idx SET path = ? || substr(path, ?) WHERE substr(path, 1, ?) = ?",
            (
                os.path.join(dest, ""),
                len(src_prefix) + 1,
                len(src_prefix),
                src_prefix,
            ),
        )
        con.commit()
        if dest_ino is not None:
            # Refresh parent/mtime, and embed entries that were never indexed
            # under their old name (e.g. a hidden temp file).
            index(dest, get_parent(dest))


def _embed(path):
    """Embed the media file at `path`; return None if unsupported or failed."""
    mime = mimetypes.guess_type(path)[0]
    if mime is None:
        return None
    try:
        if mime.startswith("audio"):
            return model.embed_audio(path)
        if mime.startswith("image"):
            return model.embed_image(path)
        if mime.startswith("video") and os.path.getsize(path) < MAX_VIDEO_BYTES:
            return model.embed_video(path)
    except Exception:
        # Best effort: one unreadable/corrupt file must not stop indexing.
        print(traceback.format_exc())
    return None


def _child_inodes(path, children):
    """Return inode numbers of `children` of `path`, skipping unstat-able ones."""
    inodes = set()
    for child in children:
        try:
            inodes.add(os.stat(os.path.join(path, child)).st_ino)
        except OSError:
            # Broken symlink, or removed since listing: treat as gone.
            pass
    return inodes


def _index_children(path, ino, is_root):
    """Sync directory `path` (inode `ino`): drop vanished children, index live ones.

    For the "/" pseudo-root (`is_root`) the children are the watched
    directories, each of which is also registered with the observer.
    """
    if is_root:
        children = watchdirs
    else:
        try:
            children = os.listdir(path)
        except OSError:
            # Unreadable directory: leave its existing rows untouched.
            print(traceback.format_exc())
            return
    # Find and unindex dead children
    live = _child_inodes(path, children)
    stored = cur.execute("SELECT id FROM idx WHERE parent = ?", (ino,)).fetchall()
    for (child_id,) in stored:
        if child_id not in live:
            # Don't unemb because might be a move
            unindex(child_id, False)
    # Index live children
    for child in children:
        if is_root:
            observer.schedule(event_handler, child, recursive=True)
        index(os.path.join(path, child), ino)


def index(path, parent):
    """Add or refresh `path` in the index, recursing into directories.

    `parent` is the inode of the containing directory, or 0 for the "/"
    pseudo-root, whose children are the watched directories. Files stay
    indexed only if `_embed` produces an embedding for them. Nonexistent
    paths, hidden entries (names starting with ".") and special files
    (neither regular file nor directory) are skipped.
    """
    if not os.path.exists(path) or os.path.basename(path).startswith("."):
        # Skip nonexistent or hidden files
        return
    print("Indexing", path, parent)
    s = os.stat(path)
    is_dir = os.path.isdir(path)
    if os.path.isfile(path):
        db_vals = cur.execute(
            "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,)
        ).fetchall()
        if len(db_vals) == 1 and (s.st_mtime, parent, path) == db_vals[0]:
            # Already in DB, unmodified
            return
        if (
            not db_vals
            # Compare with the stored mtime (column 0). A changed parent or
            # path alone means the file moved, which needs no re-embedding.
            or s.st_mtime != db_vals[0][0]
            or cur.execute(
                "SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)
            ).fetchone()
            is None
        ):
            # New, modified, or missing its embedding
            emb = _embed(path)
            if emb is None:
                # Might be in index but no longer valid
                unindex(s.st_ino)
                return
            # sqlite-vec doesn't support INSERT OR REPLACE?
            cur.execute("DELETE FROM emb WHERE id = ?", (s.st_ino,))
            cur.execute(
                "INSERT INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy())
            )
    elif not is_dir:
        # Sockets, FIFOs, device nodes: nothing to index.
        return
    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()
    if is_dir:
        _index_children(path, s.st_ino, not parent)


def unindex(id, unemb=True):
    """Remove entry `id` and, recursively, every row filed under it.

    With `unemb=False` the embeddings are kept (for the whole subtree) so an
    entry that merely moved can be re-attached without re-embedding; any that
    stay orphaned are purged by the startup cleanup.
    """
    print("Unindexing", id)
    children = cur.execute("SELECT id FROM idx WHERE parent = ?", (id,)).fetchall()
    for (child_id,) in children:
        unindex(child_id, unemb)
    cur.execute("DELETE FROM idx WHERE id = ?", (id,))
    if unemb:
        cur.execute("DELETE FROM emb WHERE id = ?", (id,))
    con.commit()


def search(text, limit):
    """Return paths of up to `limit` indexed files nearest to `text`.

    Exposed over XML-RPC; results are ordered by the distance between the
    text embedding and each stored file embedding.
    """
    print("Search", text, limit)
    emb = model.embed_text(text).cpu().numpy()
    # Runs on the main (RPC) thread while watchdog handlers may be writing
    # through the same connection: take the DB lock and use a private cursor
    # rather than the shared `cur`.
    with lock:
        res = con.execute(
            "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
            (emb, limit),
        )
        return [row[0] for row in res.fetchall()]


print("Indexing files")
watchdirs = set(map(os.path.abspath, sys.argv[1:]))
observer = Observer()
observer.start()
event_handler = EventHandler()
with lock:
    # Pretend that / is the parent of all indexed dirs
    index("/", 0)
    # Clean up emb: drop vectors whose idx row is gone (e.g. kept by
    # unindex(..., False) for a file that never reappeared).
    cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
    con.commit()


class UnixStreamXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler adapted to a Unix-domain stream socket."""

    # TCP_NODELAY cannot be set on a Unix-domain socket.
    disable_nagle_algorithm = False

    def address_string(self):
        # The default indexes client_address[0], assuming a (host, port)
        # tuple; a Unix-socket peer address is not one.
        return self.client_address


class UnixStreamXMLRPCServer(socketserver.UnixStreamServer, SimpleXMLRPCDispatcher):
    """SimpleXMLRPCServer equivalent that listens on a Unix-domain socket."""

    def __init__(
        self,
        addr,
        log_requests=True,
        allow_none=True,
        encoding=None,
        bind_and_activate=True,
        use_builtin_types=True,
    ):
        self.logRequests = log_requests
        SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding, use_builtin_types)
        socketserver.UnixStreamServer.__init__(
            self,
            addr,
            UnixStreamXMLRPCRequestHandler,
            bind_and_activate,
        )


print("Starting RPC server")
# Remove a socket file left over from a previous run so bind() succeeds.
pathlib.Path("search.sock").unlink(missing_ok=True)
server = UnixStreamXMLRPCServer("search.sock")
server.register_function(search)
server.serve_forever()