// Package websocket holds the set of known WebSocket clients, persists their
// connection records through the db package, and pings connections
// periodically to detect dead peers.
package websocket

import (
	"log"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"lst.net/utils/db"
)

var (
	// clients is the set of known clients. The readers in this file hold
	// clientsMu (read-locked) while iterating it.
	clients   = make(map[*Client]bool)
	clientsMu sync.RWMutex
)

// Client describes a single WebSocket connection together with the metadata
// that is exposed through JSON. Conn and Send are excluded from JSON; done
// and isAlive are unexported and therefore never serialized.
type Client struct {
	ClientID    uuid.UUID       `json:"client_id"`
	Conn        *websocket.Conn `json:"-"` // Excluded from JSON
	APIKey      string          `json:"api_key"`
	IPAddress   string          `json:"ip_address"`
	UserAgent   string          `json:"user_agent"`
	Send        chan []byte     `json:"-"` // Excluded from JSON
	Channels    map[string]bool `json:"channels"`
	LogLevels   []string        `json:"levels,omitempty"`
	Services    []string        `json:"services,omitempty"`
	Labels      []string        `json:"labels,omitempty"`
	ConnectedAt time.Time       `json:"connected_at"`

	done    chan struct{} // Closed by Close to stop StartHeartbeat (graceful shutdown).
	isAlive atomic.Bool   // Flipped from true to false exactly once by Close.
	//mu sync.Mutex // Protects isAlive if not using atomic
}

// SaveToDB inserts a client record built from c and, on success, copies the
// stored ClientID and ConnectedAt back onto c. A failed insert is logged and
// leaves c unchanged.
func (c *Client) SaveToDB() {
	// Convert c.Channels (map[string]bool) to map[string]any for JSONB,
	// keeping each entry's actual value rather than forcing it to true.
	channels := make(map[string]any, len(c.Channels))
	for name, subscribed := range c.Channels {
		channels[name] = subscribed
	}

	// One timestamp so ConnectedAt and LastHeartbeat start out identical.
	now := time.Now()
	clientRecord := &db.ClientRecord{
		APIKey:        c.APIKey,
		IPAddress:     c.IPAddress,
		UserAgent:     c.UserAgent,
		Channels:      db.JSONB(channels),
		ConnectedAt:   now,
		LastHeartbeat: now,
	}

	// clientRecord is already a pointer; pass it as is.
	if err := db.DB.Create(clientRecord).Error; err != nil {
		log.Println("❌ Error saving client:", err)
		return
	}

	c.ClientID = clientRecord.ClientID
	c.ConnectedAt = clientRecord.ConnectedAt
}

// MarkDisconnected sets disconnected_at to the current time on the record
// whose client_id matches c.ClientID. Failures are logged, not returned.
func (c *Client) MarkDisconnected() {
	log.Printf("Client %v just lefts us", c.ClientID)

	now := time.Now()
	res := db.DB.Model(&db.ClientRecord{}).
		Where("client_id = ?", c.ClientID).
		Updates(map[string]any{
			"disconnected_at": &now,
		})

	// Check the error first: a failed update would otherwise also be
	// reported as "no rows updated", which is misleading.
	if res.Error != nil {
		log.Println("❌ Error updating disconnected_at:", res.Error)
		return
	}
	if res.RowsAffected == 0 {
		log.Println("⚠️ No rows updated for client_id:", c.ClientID)
	}
}

// SafeClient returns a detached snapshot of c that carries only the
// descriptive fields. Conn, Send, done and isAlive are left at their zero
// values, and the Channels map and the string slices are copied so the
// snapshot does not alias c's data.
//
// NOTE(review): the copy reads c.Channels without taking a lock of its own;
// whether writers to Channels are synchronized is not visible in this file.
func (c *Client) SafeClient() *Client {
	return &Client{
		ClientID:    c.ClientID,
		APIKey:      c.APIKey,
		IPAddress:   c.IPAddress,
		UserAgent:   c.UserAgent,
		Channels:    copyChannelSet(c.Channels),
		LogLevels:   copyStringSlice(c.LogLevels),
		Services:    copyStringSlice(c.Services),
		Labels:      copyStringSlice(c.Labels),
		ConnectedAt: c.ConnectedAt,
	}
}

// copyChannelSet returns an independent copy of src; nil stays nil so the
// JSON encoding of an unset map is unchanged.
func copyChannelSet(src map[string]bool) map[string]bool {
	if src == nil {
		return nil
	}
	dst := make(map[string]bool, len(src))
	for name, subscribed := range src {
		dst[name] = subscribed
	}
	return dst
}

// copyStringSlice returns an independent copy of src; nil stays nil.
func copyStringSlice(src []string) []string {
	if src == nil {
		return nil
	}
	dst := make([]string, len(src))
	copy(dst, src)
	return dst
}

// GetAllClients returns safe representations of all clients. The result is
// nil when there are no clients.
func GetAllClients() []*Client {
	clientsMu.RLock()
	defer clientsMu.RUnlock()

	if len(clients) == 0 {
		return nil
	}

	clientList := make([]*Client, 0, len(clients))
	for client := range clients {
		clientList = append(clientList, client.SafeClient())
	}
	return clientList
}

// GetClientsByChannel returns safe representations of the clients whose
// Channels entry for channel is true. The result is nil when none match.
func GetClientsByChannel(channel string) []*Client {
	clientsMu.RLock()
	defer clientsMu.RUnlock()

	var channelClients []*Client
	for client := range clients {
		if client.Channels[channel] {
			channelClients = append(channelClients, client.SafeClient())
		}
	}
	return channelClients
}

// Heartbeat timing.
const (
	pingPeriod = 30 * time.Second // Interval between pings sent by StartHeartbeat.
	pongWait   = 60 * time.Second // Not referenced in this file; presumably the read deadline used elsewhere.
	writeWait  = 10 * time.Second // Deadline applied to each ping write.
)

// StartHeartbeat sends a ping every pingPeriod. It blocks until the client is
// no longer alive, done is closed, or a ping fails; on a failed ping it calls
// Close before returning.
func (c *Client) StartHeartbeat() {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if !c.isAlive.Load() {
				return
			}

			// WriteControl carries its own deadline and, per the
			// gorilla/websocket documentation, may be called concurrently
			// with other connection methods. WriteMessage may not, and
			// SetWriteDeadline would also affect other writers.
			deadline := time.Now().Add(writeWait)
			if err := c.Conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				log.Printf("Heartbeat failed for %s: %v", c.ClientID, err)
				c.Close()
				return
			}
		case <-c.done:
			return
		}
	}
}

// Close shuts the client down at most once: the first call that flips isAlive
// from true to false closes done, closes the connection and records the
// disconnect. Calls made when isAlive is already false do nothing.
func (c *Client) Close() {
	if !c.isAlive.CompareAndSwap(true, false) {
		return
	}

	close(c.done)
	// The connection is being discarded; its close error is deliberately
	// ignored, as in the original code.
	_ = c.Conn.Close()
	// Add any other cleanup here
	c.MarkDisconnected()
}

// TODO: work on these stats later.
// Add to your admin endpoint
// type ConnectionStats struct {
// 	TotalConnections  int    `json:"total_connections"`
// 	ActiveConnections int    `json:"active_connections"`
// 	AvgDuration       string `json:"avg_duration"`
// }

// func GetConnectionStats() ConnectionStats {
// 	// Implement your metrics tracking
// }