aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client.py3
-rw-r--r--server.py77
2 files changed, 49 insertions, 31 deletions
diff --git a/client.py b/client.py
index 180335f..8c6cbec 100644
--- a/client.py
+++ b/client.py
@@ -1,5 +1,6 @@
+import os
import sys
import xmlrpc.client
-proxy = xmlrpc.client.ServerProxy("http://localhost:8000")
+proxy = xmlrpc.client.ServerProxy(f"http://localhost:{os.environ.get("PORT", 8000)}")
print("\n".join(proxy.search(sys.argv[1], sys.argv[2])))
diff --git a/server.py b/server.py
index 6a04076..dafc26b 100644
--- a/server.py
+++ b/server.py
@@ -66,35 +66,47 @@ def index(path, parent):
print("Indexing", path, parent)
s = os.stat(path)
if os.path.isfile(path):
- res = cur.execute("SELECT time, parent FROM idx WHERE id = ?", (s.st_ino,))
+ res = cur.execute(
+ "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,)
+ )
db_vals = res.fetchall()
- if len(db_vals) == 1 and (s.st_mtime, parent) == db_vals[0]:
+ if len(db_vals) == 1 and (s.st_mtime, parent, path) == db_vals[0]:
# Already in DB, unmodified
return
- emb = None
- type = mimetypes.guess_type(path)[0]
- try:
- if type is None and os.path.getsize(path) < 2**16:
- with open(path) as f:
- emb = model.embed_text(f.read())
- elif type.startswith("audio"):
- emb = model.embed_audio(path)
- elif type.startswith("image"):
- emb = model.embed_image(path)
- elif type.startswith("video") and os.path.getsize(path) < 2**25:
- emb = model.embed_video(path)
- except:
- print(traceback.format_exc())
-
- if emb is None:
- # Might be in index but no longer valid
- unindex(s.st_ino)
- return
-
- cur.execute(
- "INSERT OR REPLACE INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy())
- )
+ if (
+ len(db_vals) == 0
+ or s.st_mtime != db_vals[0][1]
+ or len(
+ cur.execute("SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)).fetchall()
+ )
+ == 0
+ ):
+ # Modified or not in emb
+ emb = None
+ type = mimetypes.guess_type(path)[0]
+ try:
+ if type is None and os.path.getsize(path) < 2**16:
+ with open(path) as f:
+ emb = model.embed_text(f.read())
+ elif type.startswith("audio"):
+ emb = model.embed_audio(path)
+ elif type.startswith("image"):
+ emb = model.embed_image(path)
+ elif type.startswith("video") and os.path.getsize(path) < 2**25:
+ emb = model.embed_video(path)
+ except:
+ print(traceback.format_exc())
+
+ if emb is None:
+ # Might be in index but no longer valid
+ unindex(s.st_ino)
+ return
+
+ cur.execute(
+ "INSERT OR REPLACE INTO emb VALUES (?, ?)",
+ (s.st_ino, emb.cpu().numpy()),
+ )
cur.execute(
"INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
@@ -116,7 +128,8 @@ def index(path, parent):
db_children_id = res.fetchall()
for db_child_id in db_children_id:
if db_child_id[0] not in children_id:
- unindex(db_child_id[0])
+ # Don't unemb because might be a move
+ unindex(db_child_id[0], False)
# Index live children
for child in children:
@@ -125,14 +138,15 @@ def index(path, parent):
index(os.path.join(path, child), s.st_ino)
-def unindex(id):
+def unindex(id, unemb=True):
print("Unindexing", id)
res = cur.execute("SELECT id FROM idx WHERE parent = ?", (id,))
db_children_id = res.fetchall()
for db_child_id in db_children_id:
unindex(db_child_id[0])
cur.execute("DELETE FROM idx WHERE id = ?", (id,))
- cur.execute("DELETE FROM emb WHERE id = ?", (id,))
+ if unemb:
+ cur.execute("DELETE FROM emb WHERE id = ?", (id,))
con.commit()
@@ -140,7 +154,7 @@ def search(text, limit):
print("Search", text, limit)
emb = model.embed_text(text).cpu().numpy()
res = cur.execute(
- "SELECT idx.path FROM emb LEFT JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
+ "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
(emb, limit),
)
return [i[0] for i in res.fetchall()]
@@ -152,10 +166,13 @@ observer = Observer()
event_handler = EventHandler()
# Pretend that / is the parent of all indexed dirs
index("/", 0)
+# Clean up emb
+cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
+con.commit()
observer.start()
print("Starting RPC server")
-server = SimpleXMLRPCServer(("localhost", 8000))
+server = SimpleXMLRPCServer(("localhost", int(os.environ.get("PORT", 8000))))
server.register_function(search)
server.serve_forever()