aboutsummaryrefslogtreecommitdiff
path: root/server.py
blob: 04a92e28b7c1bae8b72ab4df7a453e59e1b43aab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import mimetypes
import os
import pathlib
import socketserver
import sqlite3
import sys
import threading
import traceback
from xmlrpc.server import SimpleXMLRPCDispatcher, SimpleXMLRPCRequestHandler
import pillow_avif
import sqlite_vec
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import model


print("Connecting to DB")
# check_same_thread=False because this one connection (and the cursor below)
# is used both by the filesystem event handler and by the RPC `search`
# function; access is serialized with `lock`.
con = sqlite3.connect("index.db", check_same_thread=False)
con.enable_load_extension(True)
sqlite_vec.load(con)
# Extension loading is only needed to bring in sqlite-vec; switch it back off
# so the long-lived connection cannot load anything else afterwards.
con.enable_load_extension(False)
cur = con.cursor()
# idx: one row per indexed filesystem entry, keyed by inode number.
#   parent = inode of the containing directory, time = mtime, path = full path.
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: one 1024-dimensional embedding per inode, compared by cosine distance.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024] distance_metric=cosine)"
)
con.commit()
# Guards every use of `con` / `cur` across threads.
lock = threading.Lock()


def get_parent(path):
    """Return the inode number of the directory that acts as ``path``'s parent.

    Watched root directories are recorded as direct children of ``/``; every
    other path uses the directory that actually contains it.

    Raises:
        OSError: if the parent directory cannot be stat'ed.
    """
    parent_dir = "/" if path in watchdirs else pathlib.Path(path).parent
    return os.stat(parent_dir).st_ino


class EventHandler(FileSystemEventHandler):
    """Keep the ``idx`` / ``emb`` tables in sync with filesystem events.

    All handlers run with the module-level ``lock`` held (taken in
    ``dispatch``), because they share the global ``con`` / ``cur``.
    """

    def dispatch(self, event):
        """Log ``event`` and route it to the matching ``on_*`` method under ``lock``."""
        with lock:
            print(event)
            try:
                super().dispatch(event)
            except Exception:
                # Handlers stat paths that may already have disappeared again.
                # Letting the error escape would propagate into the thread
                # that delivers events (likely stopping further delivery), so
                # log it, drop any half-done DB work and carry on.
                print(traceback.format_exc())
                con.rollback()

    def on_created(self, event):
        """Index the newly created path."""
        index(event.src_path, get_parent(event.src_path))

    def on_modified(self, event):
        """Re-index a modified file; directory modifications are ignored."""
        if not event.is_directory:
            self.on_created(event)

    def on_deleted(self, event):
        """Remove every index row recorded under the deleted path."""
        res = cur.execute("SELECT id FROM idx WHERE path = ?", (event.src_path,))
        # Materialize first: unindex() reuses the shared cursor.
        for (entry_id,) in res.fetchall():
            unindex(entry_id)

    def on_moved(self, event):
        """Record the new location of a moved entry and of everything below it."""
        src = event.src_path
        dest = event.dest_path
        # inode doesn't change after move
        s = os.stat(dest)

        # Rewrite only the moved entry itself and true descendants (paths that
        # start with "<src>/"). A plain substring replace would also touch
        # siblings such as "<src>2" and any path containing src in the middle.
        src_prefix = os.path.join(src, "")
        dest_prefix = os.path.join(dest, "")
        cur.execute("UPDATE idx SET path = ? WHERE path = ?", (dest, src))
        cur.execute(
            "UPDATE idx SET path = ? || substr(path, ?) WHERE substr(path, 1, ?) = ?",
            (dest_prefix, len(src_prefix) + 1, len(src_prefix), src_prefix),
        )
        # Refresh the moved entry's own row last so its parent, mtime and path
        # are authoritative regardless of what the updates above matched.
        cur.execute(
            "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
            (s.st_ino, get_parent(dest), s.st_mtime, dest),
        )
        con.commit()


def _embed_file(path):
    """Return an embedding for the file at ``path``, or ``None`` if there is none.

    The embedder is chosen from the guessed MIME type: audio, image and
    (below 2**25 bytes) video files go to the matching ``model.embed_*`` call.
    Files whose type cannot be guessed are read as text when smaller than
    2**16 bytes. Any error raised while embedding is printed and reported as
    ``None``.
    """
    # NOTE(review): files with a known non-media type (e.g. text/plain) are
    # never embedded; confirm that this is intended.
    mime = mimetypes.guess_type(path)[0]
    try:
        if mime is None:
            if os.path.getsize(path) < 2**16:
                with open(path) as f:
                    return model.embed_text(f.read())
        elif mime.startswith("audio"):
            return model.embed_audio(path)
        elif mime.startswith("image"):
            return model.embed_image(path)
        elif mime.startswith("video") and os.path.getsize(path) < 2**25:
            return model.embed_video(path)
    except Exception:
        print(traceback.format_exc())
    return None


def index(path, parent):
    """Add or refresh ``path`` in the index, recursing into directories.

    Args:
        path: Filesystem path to index. Entries whose basename starts with
            ``.`` are skipped.
        parent: Inode number of the directory ``path`` is recorded under.
            A falsy value marks the synthetic root, whose children are the
            watched directories (each of which is also scheduled on the
            observer).

    Files are (re-)embedded only when they are new, their mtime changed, or
    they have no row in ``emb``; a file that cannot be embedded is removed
    from the index. For directories, rows of children that no longer exist
    are dropped before the live children are indexed.
    """
    if os.path.basename(path).startswith("."):
        # Skip hidden files
        return

    print("Indexing", path, parent)
    try:
        s = os.stat(path)
    except OSError:
        # Vanished since it was discovered, or a dangling symlink.
        print(traceback.format_exc())
        return

    if os.path.isfile(path):
        row = cur.execute(
            "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,)
        ).fetchone()
        if row == (s.st_mtime, parent, path):
            # Already in DB, unmodified
            return

        if (
            row is None
            # Column 0 is `time`; column 1 is `parent`.
            or s.st_mtime != row[0]
            or cur.execute("SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)).fetchone()
            is None
        ):
            # Modified or not in emb
            emb = _embed_file(path)
            if emb is None:
                # Might be in index but no longer valid
                unindex(s.st_ino)
                return

            # sqlite-vec doesn't support INSERT OR REPLACE?
            cur.execute("DELETE FROM emb WHERE id = ?", (s.st_ino,))
            cur.execute("INSERT INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy()))

    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()

    if os.path.isdir(path):
        if parent:
            try:
                children = os.listdir(path)
            except OSError:
                # Directory disappeared or is unreadable; leave its rows alone.
                print(traceback.format_exc())
                return
        else:
            children = watchdirs

        # Find and unindex dead children
        children_id = set()
        for child in children:
            try:
                children_id.add(os.stat(os.path.join(path, child)).st_ino)
            except OSError:
                # Unstat-able child (e.g. dangling symlink) counts as dead.
                continue
        res = cur.execute("SELECT id FROM idx WHERE parent = ?", (s.st_ino,))
        for (db_child_id,) in res.fetchall():
            if db_child_id not in children_id:
                # Don't unemb because might be a move
                unindex(db_child_id, False)

        # Index live children
        for child in children:
            if not parent:
                observer.schedule(event_handler, child, recursive=True)
            index(os.path.join(path, child), s.st_ino)


def unindex(id, unemb=True):
    """Remove the entry with inode ``id`` and all of its descendants from ``idx``.

    Args:
        id: Inode number of the entry to remove.
        unemb: When true, also delete the entry's row in ``emb``. This flag
            applies to ``id`` itself only; descendants are removed with the
            default (their embeddings are always deleted).
    """
    print("Unindexing", id)
    # Collect the children up front: the recursive calls reuse the shared cursor.
    child_rows = cur.execute(
        "SELECT id FROM idx WHERE parent = ?", (id,)
    ).fetchall()
    for (child_id,) in child_rows:
        unindex(child_id)

    cur.execute("DELETE FROM idx WHERE id = ?", (id,))
    if unemb:
        cur.execute("DELETE FROM emb WHERE id = ?", (id,))
    con.commit()


def search(text, limit):
    """Return the paths of the ``limit`` indexed files closest to ``text``.

    Args:
        text: Query string; embedded with ``model.embed_text``.
        limit: Number of nearest neighbours requested from the ``emb`` table.

    Returns:
        A list of ``idx.path`` values ordered by increasing distance.
    """
    print("Search", text, limit)
    emb = model.embed_text(text).cpu().numpy()
    # The cursor is shared with the filesystem event handler, which only
    # touches it while holding `lock`; take the same lock here so a query
    # cannot interleave with an index update on the same cursor.
    with lock:
        res = cur.execute(
            "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
            (emb, limit),
        )
        return [i[0] for i in res.fetchall()]


print("Indexing files")
# Directories to watch, taken from the command line and made absolute.
watchdirs = {os.path.abspath(arg) for arg in sys.argv[1:]}
observer = Observer()
observer.start()
event_handler = EventHandler()
with lock:
    # Pretend that / is the parent of all indexed dirs: indexing it with a
    # parent of 0 walks every watched directory and schedules it on the observer.
    index("/", 0)
    # Clean up emb: drop embeddings whose idx row no longer exists.
    cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
    con.commit()


class UnixStreamXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler adapted for connections over a Unix socket."""

    # Nagle's algorithm is a TCP option; keep the handler from trying to
    # configure it on a non-TCP socket.
    disable_nagle_algorithm = False

    def address_string(self):
        """Return the peer address as-is instead of indexing into it."""
        peer = self.client_address
        return peer


class UnixStreamXMLRPCServer(socketserver.UnixStreamServer, SimpleXMLRPCDispatcher):
    """XML-RPC server that listens on a Unix stream socket.

    Combines ``socketserver.UnixStreamServer`` (transport) with
    ``SimpleXMLRPCDispatcher`` (method registry and dispatch); requests are
    handled by ``UnixStreamXMLRPCRequestHandler``.
    """

    def __init__(
        self,
        addr,
        log_requests=True,
        allow_none=True,
        encoding=None,
        bind_and_activate=True,
        use_builtin_types=True,
    ):
        """Set up the dispatcher, then the socket server.

        Args:
            addr: Filesystem path of the socket.
            log_requests: Stored as ``logRequests`` on the server.
            allow_none: Passed to ``SimpleXMLRPCDispatcher``.
            encoding: Passed to ``SimpleXMLRPCDispatcher``.
            bind_and_activate: Passed to ``UnixStreamServer``.
            use_builtin_types: Passed to ``SimpleXMLRPCDispatcher``.
        """
        self.logRequests = log_requests
        SimpleXMLRPCDispatcher.__init__(
            self,
            allow_none=allow_none,
            encoding=encoding,
            use_builtin_types=use_builtin_types,
        )
        socketserver.UnixStreamServer.__init__(
            self,
            addr,
            RequestHandlerClass=UnixStreamXMLRPCRequestHandler,
            bind_and_activate=bind_and_activate,
        )


print("Starting RPC server")
# Remove a socket file left over from a previous run before binding again.
pathlib.Path("search.sock").unlink(missing_ok=True)
server = UnixStreamXMLRPCServer("search.sock")
# Exposed to clients under the function's own name.
server.register_function(search, name="search")
# Blocks for the lifetime of the process.
server.serve_forever()