From a1a30cffd18e02e1061959fa3164f8237522880c Mon Sep 17 00:00:00 2001 From: Blake Matthes Date: Tue, 29 Jul 2025 11:29:59 -0500 Subject: [PATCH] refactor(ws): ws logging and channel manager added no auth currently --- .../cmd/services/websocket/channel_manager.go | 174 ++++++++++++++ backend/cmd/services/websocket/client.go | 93 ------- backend/cmd/services/websocket/handler.go | 163 ------------- backend/cmd/services/websocket/label.go | 24 ++ .../allLogs.go => log_services.go} | 54 ++--- backend/cmd/services/websocket/routes.go | 58 +++-- backend/cmd/services/websocket/ws_client.go | 168 +++++++++++++ backend/cmd/services/websocket/ws_handler.go | 227 ++++++++++++++++++ backend/main.go | 10 +- backend/utils/db/db.go | 18 +- backend/utils/db/{config.go => settings.go} | 18 +- backend/utils/logger/logger.go | 5 +- 12 files changed, 693 insertions(+), 319 deletions(-) create mode 100644 backend/cmd/services/websocket/channel_manager.go delete mode 100644 backend/cmd/services/websocket/client.go delete mode 100644 backend/cmd/services/websocket/handler.go create mode 100644 backend/cmd/services/websocket/label.go rename backend/cmd/services/websocket/{channelMGT/allLogs.go => log_services.go} (62%) create mode 100644 backend/cmd/services/websocket/ws_client.go create mode 100644 backend/cmd/services/websocket/ws_handler.go rename backend/utils/db/{config.go => settings.go} (81%) diff --git a/backend/cmd/services/websocket/channel_manager.go b/backend/cmd/services/websocket/channel_manager.go new file mode 100644 index 0000000..930dd1f --- /dev/null +++ b/backend/cmd/services/websocket/channel_manager.go @@ -0,0 +1,174 @@ +package websocket + +import ( + "encoding/json" + "log" + "strings" + "sync" + + logging "lst.net/utils/logger" +) + +type Channel struct { + Name string + Clients map[*Client]bool + Register chan *Client + Unregister chan *Client + Broadcast chan []byte + lock sync.RWMutex +} + +var ( + channels = make(map[string]*Channel) + channelsMu sync.RWMutex +) + +// InitializeChannels creates and returns all channels +func InitializeChannels() { + channelsMu.Lock() + defer channelsMu.Unlock() + + channels["logServices"] = NewChannel("logServices") + channels["labels"] = NewChannel("labels") + // Add more channels here as needed +} + +func NewChannel(name string) *Channel { + return &Channel{ + Name: name, + Clients: make(map[*Client]bool), + Register: make(chan *Client), + Unregister: make(chan *Client), + Broadcast: make(chan []byte), + } +} + +func GetChannel(name string) (*Channel, bool) { + channelsMu.RLock() + defer channelsMu.RUnlock() + ch, exists := channels[name] + return ch, exists +} + +func GetAllChannels() map[string]*Channel { + channelsMu.RLock() + defer channelsMu.RUnlock() + + chs := make(map[string]*Channel) + for k, v := range channels { + chs[k] = v + } + return chs +} + +func StartAllChannels() { + + channelsMu.RLock() + defer channelsMu.RUnlock() + + for _, ch := range channels { + go ch.RunChannel() + } +} + +func CleanupChannels() { + channelsMu.Lock() + defer channelsMu.Unlock() + + for _, ch := range channels { + close(ch.Broadcast) + // Add any other cleanup needed + } + channels = make(map[string]*Channel) +} + +func StartBroadcasting(broadcaster chan logging.Message, channels map[string]*Channel) { + go func() { + for msg := range broadcaster { + switch msg.Channel { + case "logServices": + // Just forward the message - filtering happens in RunChannel() + messageBytes, err := json.Marshal(msg) + if err != nil { + log.Printf("Error marshaling message: %v", err) + continue + } + channels["logServices"].Broadcast <- messageBytes + + case "labels": + // Future labels handling + messageBytes, err := json.Marshal(msg) + if err != nil { + log.Printf("Error marshaling message: %v", err) + continue + } + channels["labels"].Broadcast <- messageBytes + + default: + log.Printf("Received message for unknown channel: %s", msg.Channel) + } + } + }() +} + +func contains(slice []string, item string) bool { + // Empty filter slice means "match all" + if len(slice) == 0 { + return true + } + + // Case-insensitive comparison + item = strings.ToLower(item) + for _, s := range slice { + if strings.ToLower(s) == item { + return true + } + } + return false +} + +// Updated Channel.RunChannel() for logServices filtering +func (ch *Channel) RunChannel() { + for { + select { + case client := <-ch.Register: + ch.lock.Lock() + ch.Clients[client] = true + ch.lock.Unlock() + + case client := <-ch.Unregister: + ch.lock.Lock() + delete(ch.Clients, client) + ch.lock.Unlock() + + case message := <-ch.Broadcast: + var msg logging.Message + if err := json.Unmarshal(message, &msg); err != nil { + continue + } + + ch.lock.RLock() + for client := range ch.Clients { + // Special filtering for logServices + if ch.Name == "logServices" { + logLevel, _ := msg.Meta["level"].(string) + logService, _ := msg.Meta["service"].(string) + + levelMatch := len(client.LogLevels) == 0 || contains(client.LogLevels, logLevel) + serviceMatch := len(client.Services) == 0 || contains(client.Services, logService) + + if !levelMatch || !serviceMatch { + continue + } + } + + select { + case client.Send <- message: + default: + ch.Unregister <- client + } + } + ch.lock.RUnlock() + } + } +} diff --git a/backend/cmd/services/websocket/client.go b/backend/cmd/services/websocket/client.go deleted file mode 100644 index 5525296..0000000 --- a/backend/cmd/services/websocket/client.go +++ /dev/null @@ -1,93 +0,0 @@ -package socketio - -import ( - "log" - "sync" - - "github.com/google/uuid" - "github.com/gorilla/websocket" - logging "lst.net/utils/logger" -) - -type Client struct { - ClientID uuid.UUID - Conn *websocket.Conn - APIKey string - IPAddress string - UserAgent string - Send chan []byte - Channels map[string]bool // e.g., {"logs": true, "labels": true} -} - -var clients = make(map[*Client]bool) - -var clientsLock sync.Mutex - -func init() { - var broadcast = make(chan string) - go func() { - for { - msg := <-broadcast - - clientsLock.Lock() - for client := range clients { - if client.Channels["logs"] { - err := client.Conn.WriteMessage(websocket.TextMessage, []byte(msg)) - if err != nil { - log.Println("Write error:", err) - client.Conn.Close() - //client.MarkDisconnected() - delete(clients, client) - } - } - } - clientsLock.Unlock() - } - }() -} - -func StartBroadcasting(broadcaster chan logging.Message) { - go func() { - log.Println("StartBroadcasting goroutine started") - for { - msg := <-broadcaster - //log.Printf("Received msg on broadcaster: %+v\n", msg) - clientsLock.Lock() - for client := range clients { - if client.Channels[msg.Channel] { - log.Println("Sending message to client") - err := client.Conn.WriteJSON(msg) - if err != nil { - log.Println("Write error:", err) - client.Conn.Close() - client.MarkDisconnected() - delete(clients, client) - } - } else { - log.Println("Skipping client, channel mismatch") - } - } - clientsLock.Unlock() - } - }() -} - -// func (c *Client) JoinChannel(name string) { -// ch := GetOrCreateChannel(name) -// c.Channels[name] = ch -// ch.Register <- c -// } - -// func (c *Client) LeaveChannel(name string) { -// if ch, ok := c.Channels[name]; ok { -// ch.Unregister <- c -// delete(c.Channels, name) -// } -// } - -func (c *Client) Disconnect() { - // for _, ch := range c.Channels { - // ch.Unregister <- c - // } - close(c.Send) -} diff --git a/backend/cmd/services/websocket/handler.go b/backend/cmd/services/websocket/handler.go deleted file mode 100644 index c10b2e5..0000000 --- a/backend/cmd/services/websocket/handler.go +++ /dev/null @@ -1,163 +0,0 @@ -package socketio - -import ( - "encoding/json" - "log" - "net/http" - "time" - - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "lst.net/utils/db" -) - -type JoinPayload struct { - Channel string `json:"channel"` - Services []string `json:"services,omitempty"` - APIKey string `json:"apiKey"` -} - -// type Channel struct { -// Name string -// Clients map[*Client]bool -// Register chan *Client -// Unregister chan *Client -// Broadcast chan Message -// } - -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, // allow all origins; customize for prod -} - -func SocketHandler(c *gin.Context) { - // Upgrade HTTP to websocket - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - log.Println("Failed to upgrade:", err) - c.AbortWithStatus(http.StatusBadRequest) - return - } - defer conn.Close() - - // Create client struct - client := &Client{ - Conn: conn, - IPAddress: c.ClientIP(), - UserAgent: c.Request.UserAgent(), - Channels: make(map[string]bool), - } - - clientsLock.Lock() - clients[client] = true - clientsLock.Unlock() - - defer func() { - clientsLock.Lock() - delete(clients, client) - clientsLock.Unlock() - client.MarkDisconnected() - client.Disconnect() - conn.Close() - }() - - for { - // Read message from client - _, msg, err := conn.ReadMessage() - if err != nil { - log.Println("Read error:", err) - clientsLock.Lock() - delete(clients, client) - clientsLock.Unlock() - client.MarkDisconnected() - client.Disconnect() - break - } - - var payload JoinPayload - err = json.Unmarshal(msg, &payload) - if err != nil { - log.Println("Invalid JSON payload:", err) - clientsLock.Lock() - delete(clients, client) - clientsLock.Unlock() - client.MarkDisconnected() - client.Disconnect() - continue - } - - // Simple API key check (replace with real auth) - if payload.APIKey == "" { - conn.WriteMessage(websocket.TextMessage, []byte("Missing API Key")) - continue - } - client.APIKey = payload.APIKey - - // Handle channel subscription, add more here as we get more in. - switch payload.Channel { - case "logs": - client.Channels["logs"] = true - case "logServices": - for _, svc := range payload.Services { - client.Channels["logServices:"+svc] = true - } - case "labels": - client.Channels["labels"] = true - default: - conn.WriteMessage(websocket.TextMessage, []byte("Unknown channel")) - continue - } - - // Save client info in DB - client.SaveToDB() - - // Confirm subscription - resp := map[string]string{ - "status": "subscribed", - "channel": payload.Channel, - } - respJSON, _ := json.Marshal(resp) - conn.WriteMessage(websocket.TextMessage, respJSON) - - // You could now start pushing messages to client or keep connection open - // For demo, just wait and keep connection alive - } -} - -func (c *Client) SaveToDB() { - // Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB - channels := make(map[string]interface{}) - for ch := range c.Channels { - channels[ch] = true - } - - clientRecord := &db.ClientRecord{ - APIKey: c.APIKey, - IPAddress: c.IPAddress, - UserAgent: c.UserAgent, - Channels: db.JSONB(channels), - ConnectedAt: time.Now(), - LastHeartbeat: time.Now(), - } - - if err := db.DB.Create(&clientRecord).Error; err != nil { - log.Println("❌ Error saving client:", err) - } else { - c.ClientID = clientRecord.ClientID // ✅ Assign the generated UUID back to the client - } -} - -func (c *Client) MarkDisconnected() { - now := time.Now() - res := db.DB.Model(&db.ClientRecord{}). - Where("client_id = ?", c.ClientID). - Updates(map[string]interface{}{ - "disconnected_at": &now, - }) - - if res.RowsAffected == 0 { - log.Println("⚠️ No rows updated for client_id:", c.ClientID) - } - if res.Error != nil { - log.Println("❌ Error updating disconnected_at:", res.Error) - } -} diff --git a/backend/cmd/services/websocket/label.go b/backend/cmd/services/websocket/label.go new file mode 100644 index 0000000..1cbed18 --- /dev/null +++ b/backend/cmd/services/websocket/label.go @@ -0,0 +1,24 @@ +package websocket + +import logging "lst.net/utils/logger" + +func LabelProcessor(broadcaster chan logging.Message) { + // Initialize any label-specific listeners + // This could listen to a different PG channel or process differently + + // for { + // select { + // // Implementation depends on your label data source + // // Example: + // case labelEvent := <-someLabelChannel: + // broadcaster <- logging.Message{ + // Channel: "labels", + // Data: labelEvent.Data, + // Meta: map[string]interface{}{ + // "label": labelEvent.Label, + // "type": labelEvent.Type, + // }, + // } + // } + // } +} diff --git a/backend/cmd/services/websocket/channelMGT/allLogs.go b/backend/cmd/services/websocket/log_services.go similarity index 62% rename from backend/cmd/services/websocket/channelMGT/allLogs.go rename to backend/cmd/services/websocket/log_services.go index eb2e107..8fe6d91 100644 --- a/backend/cmd/services/websocket/channelMGT/allLogs.go +++ b/backend/cmd/services/websocket/log_services.go @@ -1,15 +1,4 @@ -package channelmgt - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "time" - - "github.com/lib/pq" - logging "lst.net/utils/logger" -) +package websocket // setup the notifiyer @@ -27,9 +16,20 @@ import ( // AFTER INSERT ON logs // FOR EACH ROW EXECUTE FUNCTION notify_new_log(); -func AllLogs(db *sql.DB, broadcaster chan logging.Message) { - fmt.Println("[AllLogs] started") - log := logging.New() +import ( + "encoding/json" + "fmt" + "os" + "time" + + "github.com/lib/pq" + logging "lst.net/utils/logger" +) + +func LogServices(broadcaster chan logging.Message) { + fmt.Println("[LogServices] started - single channel for all logs") + logger := logging.New() + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", os.Getenv("DB_HOST"), os.Getenv("DB_PORT"), @@ -41,41 +41,37 @@ func AllLogs(db *sql.DB, broadcaster chan logging.Message) { listener := pq.NewListener(dsn, 10*time.Second, time.Minute, nil) err := listener.Listen("new_log") if err != nil { - log.Panic("Failed to LISTEN on new_log", "logger", map[string]interface{}{ + logger.Panic("Failed to LISTEN on new_log", "logger", map[string]interface{}{ "error": err.Error(), }) } - fmt.Println("Listening for new logs...") - + fmt.Println("Listening for all logs through single logServices channel...") for { select { case notify := <-listener.Notify: if notify != nil { - fmt.Println("New log notification received") - - // Unmarshal the JSON payload of the inserted row var logData map[string]interface{} if err := json.Unmarshal([]byte(notify.Extra), &logData); err != nil { - log.Error("Failed to unmarshal notification payload", "logger", map[string]interface{}{ + logger.Error("Failed to unmarshal notification payload", "logger", map[string]interface{}{ "error": err.Error(), }) continue } - // Build message to broadcast - msg := logging.Message{ - Channel: "logs", // This matches your logs channel name + // Always send to logServices channel + broadcaster <- logging.Message{ + Channel: "logServices", Data: logData, + Meta: map[string]interface{}{ + "level": logData["level"], + "service": logData["service"], + }, } - - broadcaster <- msg - //fmt.Printf("[Broadcasting] sending: %+v\n", msg) } case <-time.After(90 * time.Second): go func() { - log.Debug("Re-pinging Postgres LISTEN", "logger", map[string]interface{}{}) listener.Ping() }() } diff --git a/backend/cmd/services/websocket/routes.go b/backend/cmd/services/websocket/routes.go index b2e2890..e9418db 100644 --- a/backend/cmd/services/websocket/routes.go +++ b/backend/cmd/services/websocket/routes.go @@ -1,25 +1,55 @@ -package socketio +package websocket import ( - "github.com/gin-gonic/gin" + "net/http" - channelmgt "lst.net/cmd/services/websocket/channelMGT" - "lst.net/utils/db" + "github.com/gin-gonic/gin" logging "lst.net/utils/logger" ) -var broadcaster = make(chan logging.Message) // define broadcaster here so it’s accessible +var ( + broadcaster = make(chan logging.Message) +) func RegisterSocketRoutes(r *gin.Engine) { - sqlDB, err := db.DB.DB() - if err != nil { - panic(err) + // Initialize all channels + InitializeChannels() + + // Start channel processors + StartAllChannels() + + // Start background services + go LogServices(broadcaster) + go StartBroadcasting(broadcaster, channels) + + // WebSocket route + r.GET("/ws", func(c *gin.Context) { + SocketHandler(c, channels) + }) + + r.GET("/ws/clients", AdminAuthMiddleware(), handleGetClients) +} + +func handleGetClients(c *gin.Context) { + channel := c.Query("channel") + + var clientList []*Client + if channel != "" { + clientList = GetClientsByChannel(channel) + } else { + clientList = GetAllClients() } - // channels - go channelmgt.AllLogs(sqlDB, broadcaster) - go StartBroadcasting(broadcaster) - - wsGroup := r.Group("/ws") - wsGroup.GET("/connect", SocketHandler) + c.JSON(http.StatusOK, gin.H{ + "count": len(clientList), + "clients": clientList, + }) +} + +func AdminAuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // Implement your admin authentication logic + // Example: Check API key or JWT token + c.Next() + } } diff --git a/backend/cmd/services/websocket/ws_client.go b/backend/cmd/services/websocket/ws_client.go new file mode 100644 index 0000000..c961a17 --- /dev/null +++ b/backend/cmd/services/websocket/ws_client.go @@ -0,0 +1,168 @@ +package websocket + +import ( + "log" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "lst.net/utils/db" +) + +var ( + clients = make(map[*Client]bool) + clientsMu sync.RWMutex +) + +type Client struct { + ClientID uuid.UUID `json:"client_id"` + Conn *websocket.Conn `json:"-"` // Excluded from JSON + APIKey string `json:"api_key"` + IPAddress string `json:"ip_address"` + UserAgent string `json:"user_agent"` + Send chan []byte `json:"-"` // Excluded from JSON + Channels map[string]bool `json:"channels"` + LogLevels []string `json:"levels,omitempty"` + Services []string `json:"services,omitempty"` + Labels []string `json:"labels,omitempty"` + ConnectedAt time.Time `json:"connected_at"` + done chan struct{} // For graceful shutdown + isAlive atomic.Bool + //mu sync.Mutex // Protects isAlive if not using atomic +} + +func (c *Client) SaveToDB() { + // Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB + channels := make(map[string]interface{}) + for ch := range c.Channels { + channels[ch] = true + } + + clientRecord := &db.ClientRecord{ + APIKey: c.APIKey, + IPAddress: c.IPAddress, + UserAgent: c.UserAgent, + Channels: db.JSONB(channels), + ConnectedAt: time.Now(), + LastHeartbeat: time.Now(), + } + + if err := db.DB.Create(&clientRecord).Error; err != nil { + log.Println("❌ Error saving client:", err) + + } else { + c.ClientID = clientRecord.ClientID + c.ConnectedAt = clientRecord.ConnectedAt + } +} + +func (c *Client) MarkDisconnected() { + log.Printf("Client %v just lefts us", c.ClientID) + now := time.Now() + res := db.DB.Model(&db.ClientRecord{}). + Where("client_id = ?", c.ClientID). + Updates(map[string]interface{}{ + "disconnected_at": &now, + }) + + if res.RowsAffected == 0 { + log.Println("⚠️ No rows updated for client_id:", c.ClientID) + } + if res.Error != nil { + log.Println("❌ Error updating disconnected_at:", res.Error) + } +} + +func (c *Client) SafeClient() *Client { + return &Client{ + ClientID: c.ClientID, + APIKey: c.APIKey, + IPAddress: c.IPAddress, + UserAgent: c.UserAgent, + Channels: c.Channels, + LogLevels: c.LogLevels, + Services: c.Services, + Labels: c.Labels, + ConnectedAt: c.ConnectedAt, + } +} + +// GetAllClients returns safe representations of all clients +func GetAllClients() []*Client { + clientsMu.RLock() + defer clientsMu.RUnlock() + + var clientList []*Client + for client := range clients { + clientList = append(clientList, client.SafeClient()) + } + return clientList +} + +// GetClientsByChannel returns clients in a specific channel +func GetClientsByChannel(channel string) []*Client { + clientsMu.RLock() + defer clientsMu.RUnlock() + + var channelClients []*Client + for client := range clients { + if client.Channels[channel] { + channelClients = append(channelClients, client.SafeClient()) + } + } + return channelClients +} + +// heat beat stuff +const ( + pingPeriod = 30 * time.Second + pongWait = 60 * time.Second + writeWait = 10 * time.Second +) + +func (c *Client) StartHeartbeat() { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if !c.isAlive.Load() { // Correct way to read atomic.Bool + return + } + + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + log.Printf("Heartbeat failed for %s: %v", c.ClientID, err) + c.Close() + return + } + + case <-c.done: + return + } + } +} + +func (c *Client) Close() { + if c.isAlive.CompareAndSwap(true, false) { // Atomic swap + close(c.done) + c.Conn.Close() + // Add any other cleanup here + c.MarkDisconnected() + } +} + +// work on this stats later +// Add to your admin endpoint +// type ConnectionStats struct { +// TotalConnections int `json:"total_connections"` +// ActiveConnections int `json:"active_connections"` +// AvgDuration string `json:"avg_duration"` +// } + +// func GetConnectionStats() ConnectionStats { +// // Implement your metrics tracking +// } diff --git a/backend/cmd/services/websocket/ws_handler.go b/backend/cmd/services/websocket/ws_handler.go new file mode 100644 index 0000000..a48bde4 --- /dev/null +++ b/backend/cmd/services/websocket/ws_handler.go @@ -0,0 +1,227 @@ +package websocket + +import ( + "encoding/json" + "log" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +type JoinPayload struct { + Channel string `json:"channel"` + APIKey string `json:"apiKey"` + Services []string `json:"services,omitempty"` + Levels []string `json:"levels,omitempty"` + Labels []string `json:"labels,omitempty"` +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, // allow all origins; customize for prod + HandshakeTimeout: 15 * time.Second, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, +} + +func SocketHandler(c *gin.Context, channels map[string]*Channel) { + // Upgrade HTTP to WebSocket + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Println("WebSocket upgrade failed:", err) + return + } + //defer conn.Close() + + // Set ping handler on the connection + conn.SetPingHandler(func(appData string) error { + log.Println("Received ping:", appData) + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // Reset read timeout + return nil // Return nil to send pong automatically + }) + + // Optional: Custom pong handler + conn.SetPongHandler(func(appData string) error { + log.Println("Received pong:", appData) + return nil + }) + + // Create new client + client := &Client{ + Conn: conn, + Send: make(chan []byte, 256), // Buffered channel + Channels: make(map[string]bool), + IPAddress: c.ClientIP(), + UserAgent: c.Request.UserAgent(), + done: make(chan struct{}), + } + + client.isAlive.Store(true) + // Add to global clients map + clientsMu.Lock() + clients[client] = true + clientsMu.Unlock() + + // Save initial connection to DB + client.SaveToDB() + // Save initial connection to DB + // if err := client.SaveToDB(); err != nil { + // log.Println("Failed to save client to DB:", err) + // conn.Close() + // return + // } + + //client.StartHeartbeat() + // Cleanup on disconnect + defer func() { + // Unregister from all channels + for channelName := range client.Channels { + if ch, exists := channels[channelName]; exists { + ch.Unregister <- client + } + } + + // Remove from global clients map + clientsMu.Lock() + delete(clients, client) + clientsMu.Unlock() + + // Mark disconnected in DB + client.MarkDisconnected() + + // Close connection + conn.Close() + log.Printf("Client disconnected: %s", client.ClientID) + }() + + client.Conn.SetPingHandler(func(appData string) error { + log.Printf("Custom ping handler for client %s", client.ClientID) + return nil + }) + + // Send welcome message immediately + welcomeMsg := map[string]string{ + "status": "connected", + "message": "Welcome to the WebSocket server. Send subscription request to begin.", + } + if err := conn.WriteJSON(welcomeMsg); err != nil { + log.Println("Failed to send welcome message:", err) + return + } + + // Message handling goroutine + go func() { + defer func() { + // Cleanup on disconnect + for channelName := range client.Channels { + if ch, exists := channels[channelName]; exists { + ch.Unregister <- client + } + } + close(client.Send) + client.MarkDisconnected() + }() + + for { + _, msg, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + log.Printf("Client disconnected unexpectedly: %v", err) + } + break + } + + var payload struct { + Channel string `json:"channel"` + APIKey string `json:"apiKey"` + Services []string `json:"services,omitempty"` + Levels []string `json:"levels,omitempty"` + Labels []string `json:"labels,omitempty"` + } + + if err := json.Unmarshal(msg, &payload); err != nil { + conn.WriteJSON(map[string]string{"error": "invalid payload format"}) + continue + } + + // Validate API key (implement your own validateAPIKey function) + // if payload.APIKey == "" || !validateAPIKey(payload.APIKey) { + // conn.WriteJSON(map[string]string{"error": "invalid or missing API key"}) + // continue + // } + + if payload.APIKey == "" { + conn.WriteMessage(websocket.TextMessage, []byte("Missing API Key")) + continue + } + client.APIKey = payload.APIKey + + // Handle channel subscription + switch payload.Channel { + case "logServices": + // Unregister from other channels if needed + if client.Channels["labels"] { + channels["labels"].Unregister <- client + delete(client.Channels, "labels") + } + + // Update client filters + client.Services = payload.Services + client.LogLevels = payload.Levels + + // Register to channel + channels["logServices"].Register <- client + client.Channels["logServices"] = true + + conn.WriteJSON(map[string]string{ + "status": "subscribed", + "channel": "logServices", + }) + + case "labels": + // Unregister from other channels if needed + if client.Channels["logServices"] { + channels["logServices"].Unregister <- client + delete(client.Channels, "logServices") + } + + // Set label filters if provided + if payload.Labels != nil { + client.Labels = payload.Labels + } + + // Register to channel + channels["labels"].Register <- client + client.Channels["labels"] = true + + // Update DB record + client.SaveToDB() + // if err := client.SaveToDB(); err != nil { + // log.Println("Failed to update client labels:", err) + // } + + conn.WriteJSON(map[string]interface{}{ + "status": "subscribed", + "channel": "labels", + "filters": client.Labels, + }) + + default: + conn.WriteJSON(map[string]string{ + "error": "invalid channel", + "available_channels": "logServices, labels", + }) + } + } + }() + + // Send messages to client + for message := range client.Send { + if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { + log.Println("Write error:", err) + break + } + } +} diff --git a/backend/main.go b/backend/main.go index bedb9ac..ec4c44f 100644 --- a/backend/main.go +++ b/backend/main.go @@ -25,8 +25,10 @@ import ( "github.com/gin-gonic/gin" "github.com/joho/godotenv" "lst.net/cmd/services/system/config" - socketio "lst.net/cmd/services/websocket" + "lst.net/cmd/services/websocket" + _ "lst.net/docs" + "lst.net/utils/db" logging "lst.net/utils/logger" ) @@ -42,7 +44,7 @@ func main() { } // Initialize DB - if err := db.InitDB(); err != nil { + if _, err := db.InitDB(); err != nil { log.Panic("Database intialize failed", "db", map[string]interface{}{ "error": err.Error(), "casue": errors.Unwrap(err), @@ -112,7 +114,7 @@ func main() { }) //logging.RegisterLoggerRoutes(r, basePath) - socketio.RegisterSocketRoutes(r) + websocket.RegisterSocketRoutes(r) config.RegisterConfigRoutes(r, basePath) r.Any(basePath+"/api", errorApiLoc) @@ -136,7 +138,7 @@ func main() { // } func errorApiLoc(c *gin.Context) { log := logging.New() - log.Info("Api endpoint hit that dose not exist", "system", map[string]interface{}{ + log.Error("Api endpoint hit that dose not exist", "system", map[string]interface{}{ "endpoint": "/api", "client_ip": c.ClientIP(), "user_agent": c.Request.UserAgent(), diff --git a/backend/utils/db/db.go b/backend/utils/db/db.go index b52f774..0a480aa 100644 --- a/backend/utils/db/db.go +++ b/backend/utils/db/db.go @@ -12,7 +12,12 @@ var DB *gorm.DB type JSONB map[string]interface{} -func InitDB() error { +type DBConfig struct { + DB *gorm.DB + DSN string +} + +func InitDB() (*DBConfig, error) { dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s", os.Getenv("DB_HOST"), os.Getenv("DB_PORT"), @@ -24,7 +29,7 @@ func InitDB() error { DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { - return fmt.Errorf("failed to connect to database: %v", err) + return nil, fmt.Errorf("failed to connect to database: %v", err) } fmt.Println("✅ Connected to database") @@ -32,12 +37,15 @@ func InitDB() error { // ensures we have the uuid stuff setup properly DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`) - err = DB.AutoMigrate(&Log{}, &Config{}, &ClientRecord{}) + err = DB.AutoMigrate(&Log{}, &Settings{}, &ClientRecord{}) if err != nil { - return fmt.Errorf("failed to auto-migrate models: %v", err) + return nil, fmt.Errorf("failed to auto-migrate models: %v", err) } fmt.Println("✅ Database migration completed successfully") - return nil + return &DBConfig{ + DB: DB, + DSN: dsn, + }, nil } diff --git a/backend/utils/db/config.go b/backend/utils/db/settings.go similarity index 81% rename from backend/utils/db/config.go rename to backend/utils/db/settings.go index e6a68ac..05b6687 100644 --- a/backend/utils/db/config.go +++ b/backend/utils/db/settings.go @@ -4,12 +4,12 @@ import ( "log" "time" + "github.com/google/uuid" "gorm.io/gorm" ) -type Config struct { - gorm.Model - ID uint `gorm:"primaryKey;autoIncrement"` +type Settings struct { + ConfigID uuid.UUID `gorm:"type:uuid;default:uuid_generate_v4();primaryKey" json:"id"` Name string `gorm:"uniqueIndex;not null"` Description string `gorm:"type:text"` Value string `gorm:"not null"` @@ -20,7 +20,7 @@ type Config struct { DeletedAt gorm.DeletedAt `gorm:"index"` } -var seedConfigData = []Config{ +var seedConfigData = []Settings{ {Name: "serverPort", Description: "The port the server will listen on if not running in docker", Value: "4000", Enabled: true}, {Name: "server", Description: "The server we will use when connecting to the alplaprod sql", Value: "usmcd1vms006", Enabled: true}, } @@ -28,7 +28,7 @@ var seedConfigData = []Config{ func SeedConfigs(db *gorm.DB) error { for _, cfg := range seedConfigData { - var existing Config + var existing Settings // Try to find config by unique name result := db.Where("name =?", cfg.Name).First(&existing) @@ -57,11 +57,11 @@ func SeedConfigs(db *gorm.DB) error { return nil } -func GetAllConfigs(db *gorm.DB) ([]Config, error) { - var configs []Config +func GetAllConfigs(db *gorm.DB) ([]Settings, error) { + var settings []Settings - result := db.Find(&configs) + result := db.Find(&settings) - return configs, result.Error + return settings, result.Error } diff --git a/backend/utils/logger/logger.go b/backend/utils/logger/logger.go index 1391f43..2e3b5e8 100644 --- a/backend/utils/logger/logger.go +++ b/backend/utils/logger/logger.go @@ -17,8 +17,9 @@ type CustomLogger struct { } type Message struct { - Channel string `json:"channel"` - Data interface{} `json:"data"` + Channel string `json:"channel"` + Data map[string]interface{} `json:"data"` + Meta map[string]interface{} `json:"meta,omitempty"` } // New creates a configured CustomLogger.