diff options
Diffstat (limited to 'server.py')
-rw-r--r-- | server.py | 35 |
1 files changed, 22 insertions, 13 deletions
@@ -3,6 +3,7 @@ import os import pathlib import sqlite3 import sys +import traceback from xmlrpc.server import SimpleXMLRPCServer import pillow_avif import sqlite_vec @@ -14,6 +15,7 @@ import model print("Connecting to DB") # Possible corruption? con = sqlite3.connect("index.db", check_same_thread=False) +con.execute("PRAGMA journal_mode=wal") con.enable_load_extension(True) sqlite_vec.load(con) cur = con.cursor() @@ -49,7 +51,10 @@ class EventHandler(FileSystemEventHandler): "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)", (s.st_ino, get_parent(event.dest_path), s.st_mtime, event.dest_path), ) - cur.execute("UPDATE idx SET path = replace(path, ?, ?)", (event.src_path, event.dest_path)) + cur.execute( + "UPDATE idx SET path = replace(path, ?, ?)", + (event.src_path, event.dest_path), + ) con.commit() @@ -65,30 +70,34 @@ def index(path, parent): emb = None type = mimetypes.guess_type(path)[0] - if type is None and os.path.getsize(path) < 2**16: - try: + try: + if type is None and os.path.getsize(path) < 2**16: with open(path) as f: emb = model.embed_text(f.read()) - except: - print("Not plain text, skipping") - elif type.startswith("audio"): - emb = model.embed_audio(path) - elif type.startswith("image"): - emb = model.embed_image(path) - elif type.startswith("video") and os.path.getsize(path) < 2**25: - emb = model.embed_video(path) + elif type.startswith("audio"): + emb = model.embed_audio(path) + elif type.startswith("image"): + emb = model.embed_image(path) + elif type.startswith("video") and os.path.getsize(path) < 2**25: + emb = model.embed_video(path) + except: + print(traceback.format_exc()) if emb is None: # Might be in index but no longer valid unindex(s.st_ino) return - cur.execute("INSERT OR REPLACE INTO emb VALUES (?, ?)", (s.st_ino, emb.numpy())) + cur.execute( + "INSERT OR REPLACE INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy()) + ) + cur.execute( "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)", (s.st_ino, parent, s.st_mtime, path), ) con.commit() + if os.path.isdir(path): if parent: children = os.listdir(path) @@ -124,7 +133,7 @@ def unindex(id): def search(text, limit): - emb = model.embed_text(text).numpy() + emb = model.embed_text(text).cpu().numpy() res = cur.execute( "SELECT idx.path FROM emb LEFT JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance", (emb, limit), |