diff options
-rw-r--r-- | client/main.go | 35 | ||||
-rw-r--r-- | server/dht.go | 126 | ||||
-rw-r--r-- | server/main.go | 45 | ||||
-rw-r--r-- | server/storage.go | 167 | ||||
-rwxr-xr-x | server/test.sh | 7 | ||||
-rw-r--r-- | server/user.go | 77 |
6 files changed, 341 insertions, 116 deletions
diff --git a/client/main.go b/client/main.go index 5397384..150e160 100644 --- a/client/main.go +++ b/client/main.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/ed25519" "encoding/base64" + "encoding/binary" "flag" "fmt" "net/http" @@ -19,10 +20,25 @@ var servers []string // Post new server list to DHT func dhtPost(s string) { - message := []byte(fmt.Sprint(time.Now().Unix()) + "\n" + strings.Join(servers, "\n")) - fmt.Print(message) - message = append(message, ed25519.Sign(privKey, message)...) - http.Post(s+"/dht/"+id, "application/octet-stream", bytes.NewBuffer(message)) + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.LittleEndian, time.Now().Unix()) + if err != nil { + panic(err) + } + _, err = buf.WriteString(strings.Join(servers, "\n")) + if err != nil { + panic(err) + } + var message []byte + _, err = buf.Read(message) + if err != nil { + panic(err) + } + _, err = buf.Write(ed25519.Sign(privKey, message)) + if err != nil { + panic(err) + } + fmt.Println(http.Post(s+"/dht/"+id, "application/octet-stream", buf)) } func main() { @@ -55,19 +71,22 @@ func main() { if err != nil { panic(err) } - pubKey := ed25519.PublicKey(pubKeyBytes) + pubKey = ed25519.PublicKey(pubKeyBytes) privKeyBytes, err := os.ReadFile("privkey") if err != nil { panic(err) } - privKey := ed25519.PublicKey(privKeyBytes) + privKey = ed25519.PrivateKey(privKeyBytes) serversBytes, err := os.ReadFile("servers") if err != nil { panic(err) } id = base64.RawURLEncoding.EncodeToString(pubKey) - servers := strings.Split(string(serversBytes), "\n") - fmt.Println(pubKey, privKey, servers) + servers = strings.Split(string(serversBytes), "\n") + if servers[0] == "" { + servers = servers[1:] + } + fmt.Println(id, servers) if flag.Arg(0) == "add" { // Add server diff --git a/server/dht.go b/server/dht.go index 6590199..7327676 100644 --- a/server/dht.go +++ b/server/dht.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "crypto/sha256" "encoding/base64" + "encoding/binary" "fmt" "io" "log" @@ -29,9 +30,19 @@ func sha256sum(s string) string { return base64.RawURLEncoding.EncodeToString(b[:]) } +// Find the position of a key in the DHT +func keyPos(key string) int { + keyPos := sort.SearchStrings(peerHashes, sha256sum(key)) + if keyPos < myPos { + keyPos += len(peerHashes) + } + return keyPos +} + // Try to peer with another server func addPeer(peer string) error { peerHash := sha256sum(peer) + // Check if already peered mu.Lock() _, ok := hashToDomain[peerHash] @@ -44,7 +55,6 @@ func addPeer(peer string) error { mu.Unlock() // Try request to peer - log.Printf("%s trying to peer with %s", me, peer) resp, err := http.Get(peer + "/peer?peer=" + me) if err != nil { // Request failed, delete peer @@ -54,25 +64,25 @@ func addPeer(peer string) error { return err } - log.Printf("%s successfully peered with %s", me, peer) + // Add peer mu.Lock() i := sort.SearchStrings(peerHashes, peerHash) peerHashes = append(peerHashes, "") copy(peerHashes[i+1:], peerHashes[i:]) peerHashes[i] = peerHash - myPos = sort.SearchStrings(peerHashes, me) - + myPos = sort.SearchStrings(peerHashes, myHash) // Distribute keys to new server for id, user := range users { phase := time.Now().Unix() / 600 - if keyPos(id + "\n" + fmt.Sprint(phase))-myPos < 5 { - go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + if keyPos(id+"\n"+fmt.Sprint(phase))-myPos < 5 { + go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) } - if keyPos(id + "\n" + fmt.Sprint(phase+1))-myPos < 5 { - go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + if keyPos(id+"\n"+fmt.Sprint(phase+1))-myPos < 5 { + go http.Post(peer+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) } } mu.Unlock() + log.Printf("%s successfully peered with %s", me, peer) // Read response body body, err := io.ReadAll(resp.Body) @@ -95,39 +105,27 @@ func peerHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) return } + mu.Lock() for _, p := range hashToDomain { fmt.Fprintf(w, "%s\n", p) } + mu.Unlock() go addPeer(peer) } -// Find the position of a key in the DHT -func keyPos(key string) int { - keyPos := sort.SearchStrings(peerHashes, sha256sum(key)) - if keyPos < myPos { - keyPos += len(peerHashes) - } - return keyPos -} - // Get the timestamp of this val -func timestamp(val []byte) int { - if len(val) < ed25519.SignatureSize { - return 0 - } - message := string(val[:len(val)-ed25519.SignatureSize]) - timestamp, err := strconv.Atoi(strings.Split(message, "\n")[0]) - if err != nil { +func timestamp(val []byte) int64 { + if len(val) < 8+ed25519.SignatureSize { return 0 } - return timestamp + ret, _ := binary.Varint(val[:8]) + return ret } // Get the value for a key from the DHT -func dhtGet(key, direct string) []byte { +func dhtGet(key string, direct bool) []byte { phase := fmt.Sprint(time.Now().Unix() / 600) - keyPos := keyPos(key + "\n" + phase) - if direct != "" && keyPos-myPos < 5 { + if direct { // Directly read from kvstore mu.Lock() val, ok := kvstore[key+"\n"+phase] @@ -139,12 +137,14 @@ func dhtGet(key, direct string) []byte { } // Contact 5 servers that store this key-value pair - var mu sync.Mutex + var mu2 sync.Mutex var wg sync.WaitGroup var latest []byte + mu.Lock() + keyPos := keyPos(key + "\n" + phase) for i := 0; i < 5 && i < len(peerHashes); i++ { wg.Add(1) - j := hashToDomain[peerHashes[(keyPos+i)%len(peerHashes)]] + j := hashToDomain[peerHashes[(keyPos-i+len(peerHashes))%len(peerHashes)]] go func() { defer wg.Done() resp, err := http.Get(j + "/dht/" + key + "?direct") @@ -159,20 +159,21 @@ func dhtGet(key, direct string) []byte { if err != nil { return } - mu.Lock() + mu2.Lock() if latest == nil || timestamp(val) > timestamp(latest) { latest = val } - mu.Unlock() + mu2.Unlock() }() } + mu.Unlock() // Wait for all to finish or time out wg.Wait() return latest } // Post a key-value pair into the DHT -func dhtPost(key, phase, direct string, val []byte) error { +func dhtPost(key, phase string, direct bool, val []byte) error { err := verify(key, val) if err != nil { return err @@ -180,21 +181,8 @@ func dhtPost(key, phase, direct string, val []byte) error { if phase == "" { phase = fmt.Sprint(time.Now().Unix() / 600) } - user, ok := users[key] - if ok { - curPhase, err := strconv.Atoi(phase) - if err != nil { - return err - } - nextPhase := time.Now().Unix()/600 + 1 - if int64(curPhase) < nextPhase && user.phase < nextPhase { - user.phase = nextPhase - go dhtPost(key, fmt.Sprint(nextPhase), "", val) - } - } - keyPos := keyPos(key + "\n" + phase) - if direct != "" && keyPos-myPos < 5 { + if direct { // Directly write to kvstore mu.Lock() curVal, ok := kvstore[key+"\n"+phase] @@ -205,11 +193,31 @@ func dhtPost(key, phase, direct string, val []byte) error { return nil } + // Post the key-value pair to the next phase if necessary + mu.Lock() + curPhase, err := strconv.Atoi(phase) + if err != nil { + return err + } + nextPhase := time.Now().Unix()/600 + 1 + if int64(curPhase) < nextPhase { + user, ok := users[key] + if ok && user.phase < nextPhase { + user.phase = nextPhase + persist(key) + } + go dhtPost(key, fmt.Sprint(nextPhase), false, val) + } + keyPos := keyPos(key + "\n" + phase) + mu.Unlock() + // Contact 5 servers that store this key-value pair + mu.Lock() for i := 0; i < 5 && i < len(peerHashes); i++ { - j := hashToDomain[peerHashes[(keyPos+i)%len(peerHashes)]] - go http.Post(j+"/dht/"+key+"?phase="+phase+"&direct=true", "application/octet-stream", bytes.NewBuffer(val)) + j := hashToDomain[peerHashes[(keyPos-i+len(peerHashes))%len(peerHashes)]] + go http.Post(j+"/dht/"+key+"?phase="+phase+"&direct", "application/octet-stream", bytes.NewBuffer(val)) } + mu.Unlock() return nil } @@ -218,7 +226,7 @@ func dhtHandler(w http.ResponseWriter, r *http.Request) { key := r.URL.Path[5:] r.ParseForm() if r.Method == "GET" { - val := dhtGet(key, r.Form.Get("direct")) + val := dhtGet(key, r.Form.Has("direct")) if val == nil { w.WriteHeader(http.StatusNotFound) return @@ -226,7 +234,7 @@ func dhtHandler(w http.ResponseWriter, r *http.Request) { w.Write(val) } else if r.Method == "POST" { val, err := io.ReadAll(r.Body) - if err != nil || dhtPost(key, r.Form.Get("phase"), r.Form.Get("direct"), val) != nil { + if err != nil || dhtPost(key, r.Form.Get("phase"), r.Form.Has("direct"), val) != nil { w.WriteHeader(http.StatusInternalServerError) return } @@ -242,11 +250,13 @@ func cleanPeers() { mu.Unlock() _, err := http.Get(peer) if err != nil { + log.Printf("Removing peer %s", peer) + // Bad response, so remove peer mu.Lock() i := sort.SearchStrings(peerHashes, sha256sum(peer)) peerHashes = append(peerHashes[:i], peerHashes[i+1:]...) - myPos = sort.SearchStrings(peerHashes, me) + myPos = sort.SearchStrings(peerHashes, myHash) // Distribute keys on this server to other servers if len(peerHashes) >= 5 { @@ -254,13 +264,13 @@ func cleanPeers() { phase := time.Now().Unix() / 600 kpos := keyPos(id + "\n" + fmt.Sprint(phase)) if kpos-i < 5 { - server := hashToDomain[peerHashes[(kpos+4)%len(peerHashes)]] - go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + server := hashToDomain[peerHashes[(kpos-4+len(peerHashes))%len(peerHashes)]] + go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase)+"&direct", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) } kpos = keyPos(id + "\n" + fmt.Sprint(phase+1)) if kpos-i < 5 { - server := hashToDomain[peerHashes[(kpos+4)%len(peerHashes)]] - go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct=true", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) + server := hashToDomain[peerHashes[(kpos-4+len(peerHashes))%len(peerHashes)]] + go http.Post(server+"/dht/"+id+"?phase="+fmt.Sprint(phase+1)+"&direct", "application/octet-stream", bytes.NewBuffer(user.dhtVal)) } } } @@ -293,8 +303,10 @@ func redistributeKeys() { for id, user := range users { nextPhase := time.Now().Unix()/600 + 1 if user.phase < nextPhase { - go dhtPost(id, fmt.Sprint(nextPhase), "", user.dhtVal) + go dhtPost(id, fmt.Sprint(nextPhase), false, user.dhtVal) } + user.phase = nextPhase + persist(id) } mu.Unlock() time.Sleep(time.Duration(rand.Intn(300)) * time.Second) diff --git a/server/main.go b/server/main.go index a342ba5..4cf49ed 100644 --- a/server/main.go +++ b/server/main.go @@ -16,6 +16,16 @@ var me string var initialPeer string var dataDir string +type LoggingHandler struct { + Handler http.Handler +} + +// Log all HTTP requests +func (lh LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.Printf("Request: %s %s%s", r.Method, me, r.URL.String()) + lh.Handler.ServeHTTP(w, r) +} + func main() { flag.StringVar(&bindAddr, "b", ":4200", "bind address") flag.StringVar(&me, "u", "http://localhost:4200", "public URL") @@ -29,20 +39,11 @@ func main() { myPos = 0 peerHashes = append(peerHashes, sha256sum(me)) hashToDomain = map[string]string{peerHashes[0]: me} - - // Start background functions - if initialPeer != "" { - go addPeer(initialPeer) - } - go cleanPeers() - go cleanKVStore() - go redistributeKeys() + kvstore = make(map[string][]byte) + users = make(map[string]user) // Load user data from disk - err := os.Mkdir(dataDir, 0755) - if err != nil { - log.Fatal(err) - } + os.Mkdir(dataDir, 0755) entries, err := os.ReadDir(dataDir) if err != nil { log.Fatal(err) @@ -51,6 +52,7 @@ func main() { id := entry.Name() reader, err := os.Open(dataDir + "/" + id + "/gob") if err != nil { + log.Fatal(err) continue } var user user @@ -59,13 +61,22 @@ func main() { users[id] = user } + // Start background functions + if initialPeer != "" { + go addPeer(initialPeer) + } + go cleanPeers() + go cleanKVStore() + go redistributeKeys() + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Hello! This is a Kela server.") }) http.HandleFunc("/peer", peerHandler) - http.HandleFunc("/user", userHandler) - http.HandleFunc("/dht", dhtHandler) - http.HandleFunc("/storage", storageHandler) - http.HandleFunc("/message", messageHandler) - log.Fatal(http.ListenAndServe(bindAddr, nil)) + http.HandleFunc("/user/", userHandler) + http.HandleFunc("/dht/", dhtHandler) + http.HandleFunc("/storage/", storageHandler) + http.HandleFunc("/message/", messageHandler) + log.Fatal(http.ListenAndServe(bindAddr, LoggingHandler{Handler: http.DefaultServeMux})) } diff --git a/server/storage.go b/server/storage.go index dd11f22..21950fa 100644 --- a/server/storage.go +++ b/server/storage.go @@ -1,19 +1,176 @@ package main import ( + "crypto/ed25519" + "fmt" + "io" "net/http" + "os" + "strconv" + "strings" + "time" ) +// Replicate a user's log to another server +func replicate(id, s string) { + for true { + mu.Lock() + // Make sure that this server is still the primary for this user + user, ok := users[id] + if !ok { + mu.Unlock() + return + } + if me != user.servers[0] { + user.nextIndex = nil + mu.Unlock() + return + } + // Make sure that the target server is still associated with this user + idx, ok := user.nextIndex[s] + if !ok { + mu.Unlock() + return + } + if idx == len(user.log) { + // Up to date + mu.Unlock() + time.Sleep(50 * time.Millisecond) + continue + } + + op := user.log[idx] + mu.Unlock() + file, _ := os.Open(op) + resp, err := http.Post(s + "/storage/" + id + "/" + op + "?idx=" + fmt.Sprint(idx), "application/octet-stream", file) + if err != nil { + time.Sleep(50 * time.Millisecond) + continue + } + b, err := io.ReadAll(resp.Body) + if err != nil { + time.Sleep(50 * time.Millisecond) + continue + } + + mu.Lock() + user.nextIndex[s], _ = strconv.Atoi(string(b)) + mu.Unlock() + } +} + + // Handle storage requests func storageHandler(w http.ResponseWriter, r *http.Request) { - // filename := r.URL.String()[5:] + pathSplit := strings.Split(r.URL.Path, "/") + id := pathSplit[1] + filename := pathSplit[2] + r.ParseForm() if r.Method == "GET" { + if r.Form.Has("direct") { + // Directly read and respond with file + file, err := os.ReadFile(dataDir + "/" + id + "/" + filename) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + w.Write(file) + return + } + val := dhtGet(id, false) + if verify(id, val) != nil { + w.WriteHeader(http.StatusNotFound) + return + } + if _, ok := users[id]; ok { + reconfigure(id, val) + } + servers := strings.Split(string(val[8:len(val)-ed25519.SignatureSize]), "\n") + if servers[0] == me { + file, err := os.ReadFile(dataDir + "/" + id + "/" + filename) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + w.Write(file) + return + } + for _, server := range servers { + resp, err := http.Get(server + "/storage/" + id + "/" + filename) + if err != nil { + continue + } + b, err := io.ReadAll(resp.Body) + if err != nil { + continue + } + w.Write(b) + return + } + w.WriteHeader(http.StatusNotFound) + } else if r.Method == "POST" { + mu.Lock() + defer mu.Unlock() + user, ok := users[id] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + b, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + if verify(id, b) != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } - } else if r.Method == "PUT" { + if r.Form.Has("idx") { + idx, err := strconv.Atoi(r.Form.Get("idx")) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } - } else if r.Method == "DELETE" { + if idx > len(user.log) { + // Missing log entries + w.Write([]byte(fmt.Sprint(len(user.log)))) + return + } + if idx < len(user.log) { + // Too many log entries + ops := make(map[string]interface{}) + for i := idx; i < len(user.log); i++ { + ops[user.log[i]] = nil + } + for op := range ops { + // Fetch older version of file + resp, err := http.Get(user.servers[0] + "/storage/" + id + "/" + op) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + b, err := io.ReadAll(resp.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + err = os.WriteFile(dataDir + "/" + id + "/" + op, b, 0644) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + } + } + } - } else { - w.WriteHeader(http.StatusMethodNotAllowed) + err = os.WriteFile(dataDir + "/" + id + "/" + filename, b, 0644) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + user.log = append(user.log, filename) } } diff --git a/server/test.sh b/server/test.sh index 592ab6d..e4c3476 100755 --- a/server/test.sh +++ b/server/test.sh @@ -1,10 +1,9 @@ #!/bin/bash trap "kill 0" EXIT -go build -./server -d 0 & +go run . -d 0 & for i in $(seq 1 9) do - sleep 0.1 - ./server -d $i -b :420$i -u http://localhost:420$i -i http://localhost:420$((i-1)) & + sleep 0.2 + go run . -d $i -b :420$i -u http://localhost:420$i -i http://localhost:420$((i-1)) & done wait diff --git a/server/user.go b/server/user.go index aa78239..6b55eb4 100644 --- a/server/user.go +++ b/server/user.go @@ -11,8 +11,11 @@ import ( ) type user struct { - dhtVal []byte - phase int64 + dhtVal []byte + phase int64 + servers []string + log []string + nextIndex map[string]int } var users map[string]user @@ -36,7 +39,7 @@ func verify(id string, body []byte) error { // Persist a user's data to disk func persist(id string) { - writer, err := os.Open(dataDir + "/" + id + "/gob") + writer, err := os.OpenFile(dataDir+"/"+id+"/gob", os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return } @@ -44,37 +47,61 @@ func persist(id string) { enc.Encode(users[id]) } +// Reconfigure a user based on a DHT get +func reconfigure(id string, dhtVal []byte) { + mu.Lock() + defer mu.Unlock() + + user := users[id] + if timestamp(dhtVal) < timestamp(user.dhtVal) { + return + } + + user.dhtVal = dhtVal + servers := strings.Split(string(dhtVal[8:len(dhtVal)-ed25519.SignatureSize]), "\n") + if servers[0] == me { + if user.nextIndex == nil { + user.nextIndex = make(map[string]int) + } + for _, server := range servers { + if _, ok := user.nextIndex[server]; !ok { + user.nextIndex[server] = len(user.log) + go replicate(id, server) + } + } + } + inServers := false + for _, server := range servers { + if server == me { + inServers = true + } + } + persist(id) + if !inServers { + delete(users, id) + _ = os.RemoveAll(id) + } +} + // Handle user configuration changes func userHandler(w http.ResponseWriter, r *http.Request) { - id := r.URL.Fragment[6:] + id := r.URL.Path[6:] // Resolve ID to server list - val := dhtGet(id, "") + val := dhtGet(id, false) if verify(id, val) != nil { w.WriteHeader(http.StatusNotFound) return } - // Check if server list contains this server - message := string(val[:len(val)-ed25519.SignatureSize]) - if !strings.Contains(message, me) { - // Delete user if they are no longer associated with this server - delete(users, id) - err := os.RemoveAll(id) - if err != nil { - w.WriteHeader(http.StatusNotFound) - return - } - w.WriteHeader(http.StatusOK) - return - } - //valSplit := strings.Split(message, "\n") - //servers := valSplit[1:len(valSplit)-1] - + mu.Lock() if _, ok := users[id]; !ok { // Add user - users[id] = user{ - dhtVal: val, - } - os.Mkdir(id, 0755) + users[id] = user{dhtVal: val} + os.Mkdir(dataDir+"/"+id, 0755) persist(id) } + mu.Unlock() + + reconfigure(id, val) + + w.WriteHeader(http.StatusOK) } |