aboutsummaryrefslogtreecommitdiff
path: root/server.py
blob: d993a302e547bd7a16cc025e84eedd4c2784a1e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import mimetypes
import os
import pathlib
import sqlite3
import sys
import traceback
from xmlrpc.server import SimpleXMLRPCServer
import pillow_avif
import sqlite_vec
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import model


print("Connecting to DB")
# A single connection and cursor are shared by every function in this file,
# so sqlite3's same-thread check is switched off.
# NOTE(review): nothing here serialises access to `cur`; if the filesystem
# event callbacks and the RPC handler run on different threads they can
# interleave statements on it (the original author flagged this as
# "Possible corruption?"). Confirm, and guard with a lock if so.
con = sqlite3.connect("index.db", check_same_thread=False)
# Use write-ahead logging for the journal.
con.execute("PRAGMA journal_mode=wal")
# Extension loading is only needed while sqlite-vec registers itself; turn it
# back off afterwards so nothing else can be loaded through this connection.
con.enable_load_extension(True)
try:
    sqlite_vec.load(con)
finally:
    con.enable_load_extension(False)
cur = con.cursor()
# idx: one row per indexed filesystem entry.
#   id     - inode number of the entry
#   parent - inode number of the containing directory (0 for the pseudo-root)
#   time   - st_mtime of the entry when it was indexed
#   path   - path of the entry
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: one 1024-dimensional embedding per indexed file, keyed by the same id.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024])"
)
con.commit()


def get_parent(path):
    """Return the inode number stored as the parent of *path* in the index.

    Watched roots are indexed as children of the pseudo-root "/", so they
    report the inode of "/". Every other path reports the inode of its
    containing directory.

    Raises:
        OSError: if the directory cannot be stat'ed.
    """
    if path in watchdirs:
        # 0 is reserved for the pseudo-root itself: index() treats parent == 0
        # as "walk the watched roots". Returning 0 here would make an event on
        # a watched directory re-index it as if it were "/", dropping its real
        # children and recording the directory as its own parent.
        return os.stat("/").st_ino
    return os.stat(pathlib.Path(path).parent).st_ino


class EventHandler(FileSystemEventHandler):
    """Keep the index in step with filesystem events on the watched roots."""

    def on_created(self, event):
        """Index the path named by the event (recursing into directories)."""
        try:
            index(event.src_path, get_parent(event.src_path))
        except FileNotFoundError:
            # Short-lived files can be gone again before the event is handled.
            print("Vanished before indexing", event.src_path)

    def on_modified(self, event):
        """Re-index a modified path; index() skips files that are unchanged."""
        self.on_created(event)

    def on_deleted(self, event):
        """Remove the deleted path, and everything indexed beneath it."""
        # The path no longer exists, so it cannot be stat'ed for its inode;
        # recover the id from the path stored in the index instead.
        rows = cur.execute(
            "SELECT id FROM idx WHERE path = ?", (event.src_path,)
        ).fetchall()
        for (entry_id,) in rows:
            unindex(entry_id)

    def on_moved(self, event):
        """Record the new location of a moved entry and of its descendants."""
        src, dest = event.src_path, event.dest_path
        try:
            # inode doesn't change after move
            s = os.stat(dest)
            new_parent = get_parent(dest)
        except FileNotFoundError:
            print("Vanished before indexing", dest)
            return
        cur.execute(
            "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
            (s.st_ino, new_parent, s.st_mtime, dest),
        )
        # Re-root only the entries that live beneath the moved path. A blanket
        # replace() over every row would also rewrite unrelated paths that
        # merely contain `src` as a substring (including the row inserted
        # above whenever `dest` itself contains `src`).
        below = src + os.sep
        cur.execute(
            "UPDATE idx SET path = ? || substr(path, ?) WHERE substr(path, 1, ?) = ?",
            (dest, len(src) + 1, len(below), below),
        )
        con.commit()


def _embed(path, size):
    """Return an embedding for the file at *path*, or None if there is none.

    Files without a guessable MIME type are embedded as text when smaller
    than 64 KiB; audio and image files are always embedded; video files are
    embedded when smaller than 32 MiB. Any failure while embedding is printed
    and reported as None.
    """
    mime = mimetypes.guess_type(path)[0]
    try:
        if mime is None:
            if size < 2**16:
                with open(path) as f:
                    return model.embed_text(f.read())
        elif mime.startswith("audio"):
            return model.embed_audio(path)
        elif mime.startswith("image"):
            return model.embed_image(path)
        elif mime.startswith("video") and size < 2**25:
            return model.embed_video(path)
    except Exception:
        # Best effort: an unreadable or undecodable file is simply not indexed.
        print(traceback.format_exc())
    return None


def index(path, parent):
    """Bring the index entry for *path* up to date, recursing into directories.

    Args:
        path: filesystem path to index.
        parent: inode number of the containing directory, as stored in
            ``idx.parent``. 0 marks the pseudo-root, whose children are the
            watched roots rather than the real directory listing.

    Raises:
        OSError: if *path* itself cannot be stat'ed or listed. Failures on
            individual children are printed and skipped.
    """
    print("Indexing", path, parent)
    s = os.stat(path)
    if os.path.isfile(path):
        res = cur.execute("SELECT time, parent FROM idx WHERE id = ?", (s.st_ino,))
        db_vals = res.fetchall()
        if len(db_vals) == 1 and (s.st_mtime, parent) == db_vals[0]:
            # Already in DB, unmodified
            return

        emb = _embed(path, s.st_size)
        if emb is None:
            # Might be in index but no longer valid
            unindex(s.st_ino)
            return

        cur.execute(
            "INSERT OR REPLACE INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy())
        )

    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()

    if os.path.isdir(path):
        names = os.listdir(path) if parent else watchdirs

        # Resolve each child once. Entries that cannot be stat'ed (dangling
        # symlinks, files removed since the listing) are treated as absent
        # instead of aborting the whole walk.
        live = []
        for name in names:
            child_path = os.path.join(path, name)
            try:
                live.append((child_path, os.stat(child_path).st_ino))
            except OSError:
                print("Skipping unreadable entry", child_path)

        # Find and unindex dead children
        live_ids = {ino for _, ino in live}
        rows = cur.execute(
            "SELECT id FROM idx WHERE parent = ?", (s.st_ino,)
        ).fetchall()
        for (db_child_id,) in rows:
            if db_child_id not in live_ids:
                unindex(db_child_id)

        # Index live children
        for child_path, _ in live:
            if not parent:
                observer.schedule(event_handler, child_path, recursive=True)
            try:
                index(child_path, s.st_ino)
            except OSError:
                # One unreadable subtree must not stop its siblings.
                print(traceback.format_exc())


def unindex(id):
    """Remove the entry *id* and all of its descendants from the index.

    Descendants are found by following ``idx.parent``. Rows are deleted from
    both ``idx`` and ``emb``, and the whole removal is committed once at the
    end.

    Args:
        id: inode number of the entry to remove.
    """
    pending = [id]
    seen = set()
    while pending:
        entry = pending.pop()
        if entry in seen:
            # A row that names itself (or an ancestor) as its parent would
            # otherwise keep the walk going forever.
            continue
        seen.add(entry)
        print("Unindexing", entry)
        rows = cur.execute(
            "SELECT id FROM idx WHERE parent = ?", (entry,)
        ).fetchall()
        # Reversed so that children are visited in the order they were returned.
        pending.extend(child for (child,) in reversed(rows))
        cur.execute("DELETE FROM idx WHERE id = ?", (entry,))
        cur.execute("DELETE FROM emb WHERE id = ?", (entry,))
    con.commit()


def search(text, limit):
    """Return indexed paths whose embeddings match the embedding of *text*.

    Args:
        text: query string, embedded with ``model.embed_text``.
        limit: value bound to ``k`` in the vector query.

    Returns:
        A list of ``idx.path`` values, in the query's ``distance`` order.
    """
    query_vec = model.embed_text(text).cpu().numpy()
    rows = cur.execute(
        "SELECT idx.path FROM emb LEFT JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
        (query_vec, limit),
    ).fetchall()
    return [row[0] for row in rows]


print("Indexing files")
# The watched roots are the command-line arguments, made absolute.
watchdirs = {os.path.abspath(arg) for arg in sys.argv[1:]}
event_handler = EventHandler()
observer = Observer()
# Pretend that / is the parent of all indexed dirs; walking it with parent 0
# indexes every watched root and schedules a watch on each one.
index(path="/", parent=0)
# Start the observer only after the initial walk has finished.
observer.start()


print("Starting RPC server")
# Expose search() over XML-RPC on localhost:8000 and serve until interrupted.
server = SimpleXMLRPCServer(addr=("localhost", 8000))
server.register_function(function=search)
server.serve_forever()