diff options
-rw-r--r-- | client.py | 3 | ||||
-rw-r--r-- | server.py | 77 |
2 files changed, 49 insertions, 31 deletions
@@ -1,5 +1,6 @@ +import os import sys import xmlrpc.client -proxy = xmlrpc.client.ServerProxy("http://localhost:8000") +proxy = xmlrpc.client.ServerProxy(f"http://localhost:{os.environ.get("PORT", 8000)}") print("\n".join(proxy.search(sys.argv[1], sys.argv[2]))) @@ -66,35 +66,47 @@ def index(path, parent): print("Indexing", path, parent) s = os.stat(path) if os.path.isfile(path): - res = cur.execute("SELECT time, parent FROM idx WHERE id = ?", (s.st_ino,)) + res = cur.execute( + "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,) + ) db_vals = res.fetchall() - if len(db_vals) == 1 and (s.st_mtime, parent) == db_vals[0]: + if len(db_vals) == 1 and (s.st_mtime, parent, path) == db_vals[0]: # Already in DB, unmodified return - emb = None - type = mimetypes.guess_type(path)[0] - try: - if type is None and os.path.getsize(path) < 2**16: - with open(path) as f: - emb = model.embed_text(f.read()) - elif type.startswith("audio"): - emb = model.embed_audio(path) - elif type.startswith("image"): - emb = model.embed_image(path) - elif type.startswith("video") and os.path.getsize(path) < 2**25: - emb = model.embed_video(path) - except: - print(traceback.format_exc()) - - if emb is None: - # Might be in index but no longer valid - unindex(s.st_ino) - return - - cur.execute( - "INSERT OR REPLACE INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy()) - ) + if ( + len(db_vals) == 0 + or s.st_mtime != db_vals[0][1] + or len( + cur.execute("SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)).fetchall() + ) + == 0 + ): + # Modified or not in emb + emb = None + type = mimetypes.guess_type(path)[0] + try: + if type is None and os.path.getsize(path) < 2**16: + with open(path) as f: + emb = model.embed_text(f.read()) + elif type.startswith("audio"): + emb = model.embed_audio(path) + elif type.startswith("image"): + emb = model.embed_image(path) + elif type.startswith("video") and os.path.getsize(path) < 2**25: + emb = model.embed_video(path) + except: + print(traceback.format_exc()) + + if emb is None: + # Might be in index but no longer valid + unindex(s.st_ino) + return + + cur.execute( + "INSERT OR REPLACE INTO emb VALUES (?, ?)", + (s.st_ino, emb.cpu().numpy()), + ) cur.execute( "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)", @@ -116,7 +128,8 @@ def index(path, parent): db_children_id = res.fetchall() for db_child_id in db_children_id: if db_child_id[0] not in children_id: - unindex(db_child_id[0]) + # Don't unemb because might be a move + unindex(db_child_id[0], False) # Index live children for child in children: @@ -125,14 +138,15 @@ def index(path, parent): index(os.path.join(path, child), s.st_ino) -def unindex(id): +def unindex(id, unemb=True): print("Unindexing", id) res = cur.execute("SELECT id FROM idx WHERE parent = ?", (id,)) db_children_id = res.fetchall() for db_child_id in db_children_id: unindex(db_child_id[0]) cur.execute("DELETE FROM idx WHERE id = ?", (id,)) - cur.execute("DELETE FROM emb WHERE id = ?", (id,)) + if unemb: + cur.execute("DELETE FROM emb WHERE id = ?", (id,)) con.commit() @@ -140,7 +154,7 @@ def search(text, limit): print("Search", text, limit) emb = model.embed_text(text).cpu().numpy() res = cur.execute( - "SELECT idx.path FROM emb LEFT JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance", + "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance", (emb, limit), ) return [i[0] for i in res.fetchall()] @@ -152,10 +166,13 @@ observer = Observer() event_handler = EventHandler() # Pretend that / is the parent of all indexed dirs index("/", 0) +# Clean up emb +cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)") +con.commit() observer.start() print("Starting RPC server") -server = SimpleXMLRPCServer(("localhost", 8000)) +server = SimpleXMLRPCServer(("localhost", int(os.environ.get("PORT", 8000)))) server.register_function(search) server.serve_forever() |