aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client.py22
-rw-r--r--server.py56
2 files changed, 63 insertions, 15 deletions
diff --git a/client.py b/client.py
index 4249bcd..2b740cd 100644
--- a/client.py
+++ b/client.py
@@ -1,8 +1,28 @@
+from http.client import HTTPConnection
import os
+import socket
import sys
import xmlrpc.client
-proxy = xmlrpc.client.ServerProxy(f"http://localhost:{os.environ.get("PORT", 8000)}")
+
+class UnixStreamHTTPConnection(HTTPConnection):
+ def connect(self):
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.sock.connect(self.host)
+
+
+class UnixStreamTransport(xmlrpc.client.Transport, object):
+ def __init__(self, socket_path):
+ self.socket_path = socket_path
+ super().__init__()
+
+ def make_connection(self, host):
+ return UnixStreamHTTPConnection(self.socket_path)
+
+
+proxy = xmlrpc.client.ServerProxy(
+ "http://localhost", transport=UnixStreamTransport("search.sock")
+)
res = proxy.search(sys.argv[1], sys.argv[2])
print("\n".join(res))
diff --git a/server.py b/server.py
index 6dd226d..04a92e2 100644
--- a/server.py
+++ b/server.py
@@ -1,11 +1,12 @@
import mimetypes
import os
import pathlib
+import socketserver
import sqlite3
import sys
import threading
import traceback
-from xmlrpc.server import SimpleXMLRPCServer
+from xmlrpc.server import SimpleXMLRPCDispatcher, SimpleXMLRPCRequestHandler
import pillow_avif
import sqlite_vec
from watchdog.observers import Observer
@@ -25,6 +26,7 @@ cur.execute(
"CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024] distance_metric=cosine)"
)
con.commit()
+lock = threading.Lock()
def get_parent(path):
@@ -36,13 +38,10 @@ def get_parent(path):
class EventHandler(FileSystemEventHandler):
- lock = threading.Lock()
-
def dispatch(self, event):
- self.lock.acquire()
- print(event)
- super().dispatch(event)
- self.lock.release()
+ with lock:
+ print(event)
+ super().dispatch(event)
def on_created(self, event):
index(event.src_path, get_parent(event.src_path))
@@ -175,16 +174,45 @@ def search(text, limit):
print("Indexing files")
watchdirs = set(map(os.path.abspath, sys.argv[1:]))
observer = Observer()
-event_handler = EventHandler()
-# Pretend that / is the parent of all indexed dirs
-index("/", 0)
-# Clean up emb
-cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
-con.commit()
observer.start()
+event_handler = EventHandler()
+with lock:
+ # Pretend that / is the parent of all indexed dirs
+ index("/", 0)
+ # Clean up emb
+ cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
+ con.commit()
+
+
+class UnixStreamXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
+ disable_nagle_algorithm = False
+
+ def address_string(self):
+ return self.client_address
+
+
+class UnixStreamXMLRPCServer(socketserver.UnixStreamServer, SimpleXMLRPCDispatcher):
+ def __init__(
+ self,
+ addr,
+ log_requests=True,
+ allow_none=True,
+ encoding=None,
+ bind_and_activate=True,
+ use_builtin_types=True,
+ ):
+ self.logRequests = log_requests
+ SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding, use_builtin_types)
+ socketserver.UnixStreamServer.__init__(
+ self,
+ addr,
+ UnixStreamXMLRPCRequestHandler,
+ bind_and_activate,
+ )
print("Starting RPC server")
-server = SimpleXMLRPCServer(("localhost", int(os.environ.get("PORT", 8000))))
+pathlib.Path("search.sock").unlink(missing_ok=True)
+server = UnixStreamXMLRPCServer("search.sock")
server.register_function(search)
server.serve_forever()