From 4368111311c48e73a11a6b24febdcc3be31a2a59 Mon Sep 17 00:00:00 2001 From: Blake Matthes Date: Tue, 29 Jul 2025 20:13:05 -0500 Subject: [PATCH] fix(websocket): errors in saving client info during ping ping --- .../cmd/services/websocket/channel_manager.go | 9 +- .../cmd/services/websocket/log_services.go | 3 +- backend/cmd/services/websocket/ws_client.go | 113 +++++++++++++++++- backend/cmd/services/websocket/ws_handler.go | 37 +++--- backend/utils/db/ws_clients.go | 2 +- 5 files changed, 136 insertions(+), 28 deletions(-) diff --git a/backend/cmd/services/websocket/channel_manager.go b/backend/cmd/services/websocket/channel_manager.go index 930dd1f..6f81650 100644 --- a/backend/cmd/services/websocket/channel_manager.go +++ b/backend/cmd/services/websocket/channel_manager.go @@ -83,6 +83,7 @@ func CleanupChannels() { } func StartBroadcasting(broadcaster chan logging.Message, channels map[string]*Channel) { + logger := logging.New() go func() { for msg := range broadcaster { switch msg.Channel { @@ -90,7 +91,9 @@ func StartBroadcasting(broadcaster chan logging.Message, channels map[string]*Ch // Just forward the message - filtering happens in RunChannel() messageBytes, err := json.Marshal(msg) if err != nil { - log.Printf("Error marshaling message: %v", err) + logger.Error("Error marshaling message", "websocket", map[string]interface{}{ + "errors": err, + }) continue } channels["logServices"].Broadcast <- messageBytes @@ -99,7 +102,9 @@ func StartBroadcasting(broadcaster chan logging.Message, channels map[string]*Ch // Future labels handling messageBytes, err := json.Marshal(msg) if err != nil { - log.Printf("Error marshaling message: %v", err) + logger.Error("Error marshaling message", "websocket", map[string]interface{}{ + "errors": err, + }) continue } channels["labels"].Broadcast <- messageBytes diff --git a/backend/cmd/services/websocket/log_services.go b/backend/cmd/services/websocket/log_services.go index 8fe6d91..e1f6881 100644 --- a/backend/cmd/services/websocket/log_services.go +++ b/backend/cmd/services/websocket/log_services.go @@ -27,9 +27,10 @@ import ( ) func LogServices(broadcaster chan logging.Message) { - fmt.Println("[LogServices] started - single channel for all logs") logger := logging.New() + logger.Info("[LogServices] started - single channel for all logs", "websocket", map[string]interface{}{}) + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", os.Getenv("DB_HOST"), os.Getenv("DB_PORT"), diff --git a/backend/cmd/services/websocket/ws_client.go b/backend/cmd/services/websocket/ws_client.go index c961a17..21e7f8c 100644 --- a/backend/cmd/services/websocket/ws_client.go +++ b/backend/cmd/services/websocket/ws_client.go @@ -1,6 +1,7 @@ package websocket import ( + "fmt" "log" "sync" "sync/atomic" @@ -9,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" "lst.net/utils/db" + logging "lst.net/utils/logger" ) var ( @@ -30,7 +32,8 @@ type Client struct { ConnectedAt time.Time `json:"connected_at"` done chan struct{} // For graceful shutdown isAlive atomic.Bool - //mu sync.Mutex // Protects isAlive if not using atomic + lastActive time.Time // Tracks last activity + } func (c *Client) SaveToDB() { @@ -59,7 +62,10 @@ func (c *Client) SaveToDB() { } func (c *Client) MarkDisconnected() { - log.Printf("Client %v just lefts us", c.ClientID) + logger := logging.New() + clientData := fmt.Sprintf("Client %v just lefts us", c.ClientID) + logger.Info(clientData, "websocket", map[string]interface{}{}) + now := time.Now() res := db.DB.Model(&db.ClientRecord{}). Where("client_id = ?", c.ClientID). @@ -68,10 +74,17 @@ func (c *Client) MarkDisconnected() { }) if res.RowsAffected == 0 { - log.Println("⚠️ No rows updated for client_id:", c.ClientID) + + logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + }) } if res.Error != nil { - log.Println("❌ Error updating disconnected_at:", res.Error) + + logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + "error": res.Error, + }) } } @@ -123,6 +136,8 @@ const ( ) func (c *Client) StartHeartbeat() { + logger := logging.New() + log.Println("Started hearbeat") ticker := time.NewTicker(pingPeriod) defer ticker.Stop() @@ -140,6 +155,27 @@ func (c *Client) StartHeartbeat() { return } + now := time.Now() + res := db.DB.Model(&db.ClientRecord{}). + Where("client_id = ?", c.ClientID). + Updates(map[string]interface{}{ + "last_heartbeat": &now, + }) + + if res.RowsAffected == 0 { + + logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + }) + } + if res.Error != nil { + + logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + "error": res.Error, + }) + } + case <-c.done: return } @@ -155,6 +191,75 @@ func (c *Client) Close() { } } +func (c *Client) startServerPings() { + ticker := time.NewTicker(60 * time.Second) // Ping every 30s + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + c.Close() // Disconnect if ping fails + return + } + case <-c.done: + return + } + } +} + +func (c *Client) markActive() { + c.lastActive = time.Now() // No mutex needed if single-writer +} + +func (c *Client) IsActive() bool { + return time.Since(c.lastActive) < 45*time.Second // 1.5x ping interval +} + +func (c *Client) updateHeartbeat() { + //fmt.Println("Updating heatbeat") + now := time.Now() + logger := logging.New() + + //fmt.Printf("Updating heartbeat for client: %s at %v\n", c.ClientID, now) + + //db.DB = db.DB.Debug() + res := db.DB.Model(&db.ClientRecord{}). + Where("client_id = ?", c.ClientID). + Updates(map[string]interface{}{ + "last_heartbeat": &now, // Explicit format + }) + //fmt.Printf("Executed SQL: %v\n", db.DB.Statement.SQL.String()) + if res.RowsAffected == 0 { + + logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + }) + } + if res.Error != nil { + + logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ + "clientID": c.ClientID, + "error": res.Error, + }) + } + // 2. Verify DB connection + if db.DB == nil { + logger.Error("DB connection is nil", "websocket", map[string]interface{}{}) + return + } + + // 3. Test raw SQL execution first + testRes := db.DB.Exec("SELECT 1") + if testRes.Error != nil { + logger.Error("DB ping failed", "websocket", map[string]interface{}{ + "error": testRes.Error, + }) + return + } +} + // work on this stats later // Add to your admin endpoint // type ConnectionStats struct { diff --git a/backend/cmd/services/websocket/ws_handler.go b/backend/cmd/services/websocket/ws_handler.go index a48bde4..b8cb51c 100644 --- a/backend/cmd/services/websocket/ws_handler.go +++ b/backend/cmd/services/websocket/ws_handler.go @@ -35,22 +35,10 @@ func SocketHandler(c *gin.Context, channels map[string]*Channel) { } //defer conn.Close() - // Set ping handler on the connection - conn.SetPingHandler(func(appData string) error { - log.Println("Received ping:", appData) - conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // Reset read timeout - return nil // Return nil to send pong automatically - }) - - // Optional: Custom pong handler - conn.SetPongHandler(func(appData string) error { - log.Println("Received pong:", appData) - return nil - }) - // Create new client client := &Client{ Conn: conn, + APIKey: "exampleAPIKey", Send: make(chan []byte, 256), // Buffered channel Channels: make(map[string]bool), IPAddress: c.ClientIP(), @@ -73,8 +61,22 @@ func SocketHandler(c *gin.Context, channels map[string]*Channel) { // return // } - //client.StartHeartbeat() - // Cleanup on disconnect + // Set handlers + conn.SetPingHandler(func(string) error { + return nil // Auto-responds with pong + }) + + conn.SetPongHandler(func(string) error { + now := time.Now() + client.markActive() // Track last pong time + client.lastActive = now + client.updateHeartbeat() + return nil + }) + + // Start server-side ping ticker + go client.startServerPings() + defer func() { // Unregister from all channels for channelName := range client.Channels { @@ -96,11 +98,6 @@ func SocketHandler(c *gin.Context, channels map[string]*Channel) { log.Printf("Client disconnected: %s", client.ClientID) }() - client.Conn.SetPingHandler(func(appData string) error { - log.Printf("Custom ping handler for client %s", client.ClientID) - return nil - }) - // Send welcome message immediately welcomeMsg := map[string]string{ "status": "connected", diff --git a/backend/utils/db/ws_clients.go b/backend/utils/db/ws_clients.go index f8a389f..e42b3f8 100644 --- a/backend/utils/db/ws_clients.go +++ b/backend/utils/db/ws_clients.go @@ -12,7 +12,7 @@ type ClientRecord struct { IPAddress string `gorm:"not null"` UserAgent string `gorm:"size:255"` ConnectedAt time.Time `gorm:"index"` - LastHeartbeat time.Time `gorm:"index"` + LastHeartbeat time.Time `gorm:"column:last_heartbeat"` Channels JSONB `gorm:"type:jsonb"` CreatedAt time.Time UpdatedAt time.Time