I'm writing a Go websocket server and I want to gracefully close the connections when my server goes down. I have a map of active connections stored in the following variable:
// connections holds the websocket connections the server is tracking, keyed
// by the id string that closeConnection later uses to delete the entry.
// NOTE(review): this map is ranged over by the shutdown goroutine in main and
// modified by closeConnection, which also runs from wsHandler; no lock is
// visible here, so confirm that access is synchronized elsewhere.
var connections = make(map[string]*websocket.Conn)
My main function looks like this:
// main installs a shutdown signal handler and then serves websocket upgrade
// requests on the port given by argPort.
//
// On SIGTERM, SIGINT, SIGQUIT or SIGHUP the handler goroutine closes every
// connection still present in connections and exits the process with status 0.
func main() {
	// ... stuff ....

	// signal.Notify does not block when delivering a signal, so the channel
	// must be buffered; with an unbuffered channel a signal that arrives while
	// the goroutine below is not yet receiving would be dropped.
	// SIGKILL is intentionally not registered: it cannot be caught, so
	// listing it has no effect.
	gracefulStop := make(chan os.Signal, 1)
	signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGHUP)

	go func() {
		sig := <-gracefulStop
		log.Printf("Exiting from process due to %+v", sig)
		log.Println("Closing all websocket connections")
		// Deleting entries while ranging over a map is permitted in Go, so
		// closeConnection may remove ids from connections during this loop.
		// NOTE(review): handler goroutines can touch connections at the same
		// time; confirm the map is protected against concurrent access.
		for id, conn := range connections {
			closeConnection(id, conn)
		}
		// os.Exit terminates immediately without running deferred functions,
		// so all cleanup has to be finished before this point.
		os.Exit(0)
	}()

	r := mux.NewRouter()
	r.HandleFunc("/{id}", wsHandler)
	if err := http.ListenAndServe(fmt.Sprintf(":%d", *argPort), r); err != nil {
		log.Println("Could not start http server")
		log.Println(err)
	}
}
`closeConnection` does 4 things:
- calls `conn.Close()`
- sets `conn` to nil
- removes the id from the map
- calls an AWS Lambda function
The same function is deferred inside the `wsHandler` function, so if a client disconnects on its own, the function is executed in the handler.
It's all working nicely, except that when I Ctrl+C the server my `closeConnection` function is called twice per client: once in the graceful stop handler and once in the `wsHandler` deferred call.
I tried to check in my `closeConnection` function whether the connection is still defined in `connections`, but the check returns true both times.
I thought they were called twice because they run in different goroutines, so I replaced the for loop above with just a `time.Sleep(2 * time.Second)`, but in this case nothing happens (the deferred `closeConnection` inside `wsHandler` is not even called).
This is what I mean:
// Experimental variant of the shutdown goroutine: the loop that closed each
// connection is commented out and replaced by a fixed two-second wait.
go func() {
// Block until a value is received from gracefulStop.
sig := <-gracefulStop
log.Printf("Exiting from process due to %+v", sig)
log.Println("Closing all websocket connections")
// for chargeboxIdentity, conn := range connections {
// chargeboxDisconnected(chargeboxIdentity, conn)
// }
// Nothing in this variant closes the connections, so each wsHandler stays
// blocked in conn.ReadMessage for the whole wait and never returns.
time.Sleep(2 * time.Second)
// os.Exit ends the process immediately and does not run deferred functions,
// which is why the deferred closeConnection in wsHandler is never called.
os.Exit(0)
}()
EDIT: Here is the `closeConnection` function:
// closeConnection tears down the client registered under id: it removes the
// id from connections, sends a normal-closure close frame on conn, closes the
// connection and notifies the "MainDisconnect" Lambda function.
//
// The function is idempotent per id: if id is no longer present in
// connections the call returns immediately. This matters because both the
// shutdown goroutine in main and the deferred call in wsHandler invoke it for
// the same client.
func closeConnection(id string, conn *websocket.Conn) {
	// Unregister before doing any slow work. The teardown below sleeps for
	// 300ms; if the entry were deleted only afterwards, a second caller
	// arriving during that window would still find the id and repeat the
	// whole teardown, including the Lambda call.
	// NOTE(review): connections is reached from several goroutines; guard
	// this check-and-delete with a mutex to close the remaining race window.
	if _, ok := connections[id]; !ok {
		return
	}
	delete(connections, id)

	log.Printf("%s (%s) disconnected", id, conn.RemoteAddr())

	// Best effort: when the peer disconnected on its own the write is
	// expected to fail, so the error is deliberately ignored.
	_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
	// Give the peer a moment to process the close frame before the
	// underlying connection goes away.
	time.Sleep(300 * time.Millisecond)
	// Best effort for the same reason as the write above.
	_ = conn.Close()

	request := LambdaPayload{ID: id}
	payload, err := json.Marshal(request)
	if err != nil {
		log.Println("Could not create payload for lambda call")
		log.Println(err)
		return
	}

	_, err = client.Invoke(&lambda.InvokeInput{FunctionName: aws.String(lambdaPrefix + "MainDisconnect"), Payload: payload})
	if err != nil {
		log.Println("Disconnect Lambda returned an error")
		log.Println(err)
	}
}
EDIT: Here's the `wsHandler` function:
// wsHandler upgrades the incoming request to a websocket connection, passes
// it to clientConnected together with the "id" route variable and, when that
// call reports true, reads messages until a read returns an error.
//
// closeConnection is deferred only after clientConnected succeeds, so it runs
// whenever the read loop ends.
func wsHandler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Could not upgrade websocket connection")
		log.Println(err)
		return
	}

	vars := mux.Vars(r)
	id := vars["id"]
	if accepted := clientConnected(id, conn); !accepted {
		return
	}
	defer closeConnection(id, conn)

	for {
		msgType, msg, err := conn.ReadMessage()
		if err != nil {
			// A failed read ends the handler; the deferred closeConnection
			// then runs.
			return
		}
		log.Printf("%s sent: %s", id, msg)
		// ... stuff ...
	}
}