package ws import ( "fmt" "sync" "sync/atomic" "time" "github.com/google/uuid" "github.com/gorilla/websocket" "gorm.io/gorm" "lst.net/internal/models" "lst.net/pkg" "lst.net/pkg/logger" ) var ( clients = make(map[*Client]bool) clientsMu sync.RWMutex ) type Client struct { ClientID uuid.UUID `json:"client_id"` Conn *websocket.Conn `json:"-"` // Excluded from JSON APIKey string `json:"api_key"` IPAddress string `json:"ip_address"` UserAgent string `json:"user_agent"` Send chan []byte `json:"-"` // Excluded from JSON Channels map[string]bool `json:"channels"` LogLevels []string `json:"levels,omitempty"` Services []string `json:"services,omitempty"` Labels []string `json:"labels,omitempty"` ConnectedAt time.Time `json:"connected_at"` done chan struct{} // For graceful shutdown isAlive atomic.Bool lastActive time.Time // Tracks last activity } func (c *Client) SaveToDB(log *logger.CustomLogger, db *gorm.DB) { // Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB channels := make(map[string]interface{}) for ch := range c.Channels { channels[ch] = true } clientRecord := &models.ClientRecord{ APIKey: c.APIKey, IPAddress: c.IPAddress, UserAgent: c.UserAgent, Channels: pkg.JSONB(channels), ConnectedAt: time.Now(), LastHeartbeat: time.Now(), } if err := db.Create(&clientRecord).Error; err != nil { log.Error("❌ Error saving client", "websocket", map[string]interface{}{ "error": err, }) } else { c.ClientID = clientRecord.ClientID c.ConnectedAt = clientRecord.ConnectedAt clientData := fmt.Sprintf("A new client %v, just connected", c.ClientID) log.Info(clientData, "websocket", map[string]interface{}{}) } } func (c *Client) MarkDisconnected(log *logger.CustomLogger, db *gorm.DB) { clientData := fmt.Sprintf("Client %v Dicconected", c.ClientID) log.Info(clientData, "websocket", map[string]interface{}{}) now := time.Now() res := db.Model(&models.ClientRecord{}). Where("client_id = ?", c.ClientID). Updates(map[string]interface{}{ "disconnected_at": &now, }) if res.RowsAffected == 0 { log.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ "clientID": c.ClientID, }) } if res.Error != nil { log.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ "clientID": c.ClientID, "error": res.Error, }) } } func (c *Client) SafeClient() *Client { return &Client{ ClientID: c.ClientID, APIKey: c.APIKey, IPAddress: c.IPAddress, UserAgent: c.UserAgent, Channels: c.Channels, LogLevels: c.LogLevels, Services: c.Services, Labels: c.Labels, ConnectedAt: c.ConnectedAt, } } // GetAllClients returns safe representations of all clients func GetAllClients() []*Client { clientsMu.RLock() defer clientsMu.RUnlock() var clientList []*Client for client := range clients { clientList = append(clientList, client.SafeClient()) } return clientList } // GetClientsByChannel returns clients in a specific channel func GetClientsByChannel(channel string) []*Client { clientsMu.RLock() defer clientsMu.RUnlock() var channelClients []*Client for client := range clients { if client.Channels[channel] { channelClients = append(channelClients, client.SafeClient()) } } return channelClients } // heat beat stuff const ( pingPeriod = 30 * time.Second pongWait = 60 * time.Second writeWait = 10 * time.Second ) func (c *Client) StartHeartbeat(log *logger.CustomLogger, db *gorm.DB) { log.Debug("Started hearbeat", "websocket", map[string]interface{}{}) ticker := time.NewTicker(pingPeriod) defer ticker.Stop() for { select { case <-ticker.C: if !c.isAlive.Load() { return } c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { log.Error("Heartbeat failed", "websocket", map[string]interface{}{ "client_id": c.ClientID, "error": err, }) c.Close(log, db) return } now := time.Now() res := db.Model(&models.ClientRecord{}). Where("client_id = ?", c.ClientID). Updates(map[string]interface{}{ "last_heartbeat": &now, }) if res.RowsAffected == 0 { log.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ "clientID": c.ClientID, }) } if res.Error != nil { log.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ "clientID": c.ClientID, "error": res.Error, }) } clientStuff := fmt.Sprintf("HeartBeat just done on: %v", c.ClientID) log.Info(clientStuff, "websocket", map[string]interface{}{ "clientID": c.ClientID, }) case <-c.done: return } } } func (c *Client) Close(log *logger.CustomLogger, db *gorm.DB) { if c.isAlive.CompareAndSwap(true, false) { // Atomic swap close(c.done) c.Conn.Close() // Add any other cleanup here c.MarkDisconnected(log, db) } } func (c *Client) startServerPings(log *logger.CustomLogger, db *gorm.DB) { ticker := time.NewTicker(60 * time.Second) // Ping every 30s defer ticker.Stop() for { select { case <-ticker.C: c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { log.Error("Server Ping failed", "websocket", map[string]interface{}{ "clientID": c.ClientID, "error": err, }) c.Close(log, db) return } case <-c.done: return } } } func (c *Client) markActive() { c.lastActive = time.Now() // No mutex needed if single-writer } func (c *Client) IsActive() bool { return time.Since(c.lastActive) < 45*time.Second // 1.5x ping interval } func (c *Client) updateHeartbeat(log *logger.CustomLogger, db *gorm.DB) { //fmt.Println("Updating heatbeat") now := time.Now() //fmt.Printf("Updating heartbeat for client: %s at %v\n", c.ClientID, now) //db.DB = db.DB.Debug() res := db.Model(&models.ClientRecord{}). Where("client_id = ?", c.ClientID). Updates(map[string]interface{}{ "last_heartbeat": &now, // Explicit format }) //fmt.Printf("Executed SQL: %v\n", db.DB.Statement.SQL.String()) if res.RowsAffected == 0 { log.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{ "clientID": c.ClientID, }) } if res.Error != nil { log.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{ "clientID": c.ClientID, "error": res.Error, }) } // 2. Verify DB connection if db == nil { log.Error("DB connection is nil", "websocket", map[string]interface{}{}) return } // 3. Test raw SQL execution first testRes := db.Exec("SELECT 1") if testRes.Error != nil { log.Error("DB ping failed", "websocket", map[string]interface{}{ "error": testRes.Error, }) return } } // work on this stats later // Add to your admin endpoint // type ConnectionStats struct { // TotalConnections int `json:"total_connections"` // ActiveConnections int `json:"active_connections"` // AvgDuration string `json:"avg_duration"` // } // func GetConnectionStats() ConnectionStats { // // Implement your metrics tracking // }