refactor(correction to folder sturcture): before we got to deep resturctures to best pactice folder
This commit is contained in:
179
backend/internal/notifications/ws/ws_channel_manager.go
Normal file
179
backend/internal/notifications/ws/ws_channel_manager.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"lst.net/pkg/logger"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
Name string
|
||||
Clients map[*Client]bool
|
||||
Register chan *Client
|
||||
Unregister chan *Client
|
||||
Broadcast chan []byte
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
channels = make(map[string]*Channel)
|
||||
channelsMu sync.RWMutex
|
||||
)
|
||||
|
||||
// InitializeChannels creates and returns all channels
|
||||
func InitializeChannels() {
|
||||
channelsMu.Lock()
|
||||
defer channelsMu.Unlock()
|
||||
|
||||
channels["logServices"] = NewChannel("logServices")
|
||||
channels["labels"] = NewChannel("labels")
|
||||
// Add more channels here as needed
|
||||
}
|
||||
|
||||
func NewChannel(name string) *Channel {
|
||||
return &Channel{
|
||||
Name: name,
|
||||
Clients: make(map[*Client]bool),
|
||||
Register: make(chan *Client),
|
||||
Unregister: make(chan *Client),
|
||||
Broadcast: make(chan []byte),
|
||||
}
|
||||
}
|
||||
|
||||
func GetChannel(name string) (*Channel, bool) {
|
||||
channelsMu.RLock()
|
||||
defer channelsMu.RUnlock()
|
||||
ch, exists := channels[name]
|
||||
return ch, exists
|
||||
}
|
||||
|
||||
func GetAllChannels() map[string]*Channel {
|
||||
channelsMu.RLock()
|
||||
defer channelsMu.RUnlock()
|
||||
|
||||
chs := make(map[string]*Channel)
|
||||
for k, v := range channels {
|
||||
chs[k] = v
|
||||
}
|
||||
return chs
|
||||
}
|
||||
|
||||
func StartAllChannels() {
|
||||
|
||||
channelsMu.RLock()
|
||||
defer channelsMu.RUnlock()
|
||||
|
||||
for _, ch := range channels {
|
||||
go ch.RunChannel()
|
||||
}
|
||||
}
|
||||
|
||||
func CleanupChannels() {
|
||||
channelsMu.Lock()
|
||||
defer channelsMu.Unlock()
|
||||
|
||||
for _, ch := range channels {
|
||||
close(ch.Broadcast)
|
||||
// Add any other cleanup needed
|
||||
}
|
||||
channels = make(map[string]*Channel)
|
||||
}
|
||||
|
||||
func StartBroadcasting(broadcaster chan logger.Message, channels map[string]*Channel) {
|
||||
logger := logger.New()
|
||||
go func() {
|
||||
for msg := range broadcaster {
|
||||
switch msg.Channel {
|
||||
case "logServices":
|
||||
// Just forward the message - filtering happens in RunChannel()
|
||||
messageBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling message", "websocket", map[string]interface{}{
|
||||
"errors": err,
|
||||
})
|
||||
continue
|
||||
}
|
||||
channels["logServices"].Broadcast <- messageBytes
|
||||
|
||||
case "labels":
|
||||
// Future labels handling
|
||||
messageBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling message", "websocket", map[string]interface{}{
|
||||
"errors": err,
|
||||
})
|
||||
continue
|
||||
}
|
||||
channels["labels"].Broadcast <- messageBytes
|
||||
|
||||
default:
|
||||
log.Printf("Received message for unknown channel: %s", msg.Channel)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func contains(slice []string, item string) bool {
|
||||
// Empty filter slice means "match all"
|
||||
if len(slice) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Case-insensitive comparison
|
||||
item = strings.ToLower(item)
|
||||
for _, s := range slice {
|
||||
if strings.ToLower(s) == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Updated Channel.RunChannel() for logServices filtering
|
||||
func (ch *Channel) RunChannel() {
|
||||
for {
|
||||
select {
|
||||
case client := <-ch.Register:
|
||||
ch.lock.Lock()
|
||||
ch.Clients[client] = true
|
||||
ch.lock.Unlock()
|
||||
|
||||
case client := <-ch.Unregister:
|
||||
ch.lock.Lock()
|
||||
delete(ch.Clients, client)
|
||||
ch.lock.Unlock()
|
||||
|
||||
case message := <-ch.Broadcast:
|
||||
var msg logger.Message
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ch.lock.RLock()
|
||||
for client := range ch.Clients {
|
||||
// Special filtering for logServices
|
||||
if ch.Name == "logServices" {
|
||||
logLevel, _ := msg.Meta["level"].(string)
|
||||
logService, _ := msg.Meta["service"].(string)
|
||||
|
||||
levelMatch := len(client.LogLevels) == 0 || contains(client.LogLevels, logLevel)
|
||||
serviceMatch := len(client.Services) == 0 || contains(client.Services, logService)
|
||||
|
||||
if !levelMatch || !serviceMatch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
ch.Unregister <- client
|
||||
}
|
||||
}
|
||||
ch.lock.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
275
backend/internal/notifications/ws/ws_client.go
Normal file
275
backend/internal/notifications/ws/ws_client.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"lst.net/internal/db"
|
||||
"lst.net/internal/models"
|
||||
"lst.net/pkg"
|
||||
"lst.net/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
clients = make(map[*Client]bool)
|
||||
clientsMu sync.RWMutex
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
ClientID uuid.UUID `json:"client_id"`
|
||||
Conn *websocket.Conn `json:"-"` // Excluded from JSON
|
||||
APIKey string `json:"api_key"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Send chan []byte `json:"-"` // Excluded from JSON
|
||||
Channels map[string]bool `json:"channels"`
|
||||
LogLevels []string `json:"levels,omitempty"`
|
||||
Services []string `json:"services,omitempty"`
|
||||
Labels []string `json:"labels,omitempty"`
|
||||
ConnectedAt time.Time `json:"connected_at"`
|
||||
done chan struct{} // For graceful shutdown
|
||||
isAlive atomic.Bool
|
||||
lastActive time.Time // Tracks last activity
|
||||
|
||||
}
|
||||
|
||||
func (c *Client) SaveToDB() {
|
||||
// Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB
|
||||
channels := make(map[string]interface{})
|
||||
for ch := range c.Channels {
|
||||
channels[ch] = true
|
||||
}
|
||||
|
||||
clientRecord := &models.ClientRecord{
|
||||
APIKey: c.APIKey,
|
||||
IPAddress: c.IPAddress,
|
||||
UserAgent: c.UserAgent,
|
||||
Channels: pkg.JSONB(channels),
|
||||
ConnectedAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
}
|
||||
|
||||
if err := db.DB.Create(&clientRecord).Error; err != nil {
|
||||
log.Println("❌ Error saving client:", err)
|
||||
|
||||
} else {
|
||||
c.ClientID = clientRecord.ClientID
|
||||
c.ConnectedAt = clientRecord.ConnectedAt
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) MarkDisconnected() {
|
||||
logger := logger.New()
|
||||
clientData := fmt.Sprintf("Client %v just lefts us", c.ClientID)
|
||||
logger.Info(clientData, "websocket", map[string]interface{}{})
|
||||
|
||||
now := time.Now()
|
||||
res := db.DB.Model(&models.ClientRecord{}).
|
||||
Where("client_id = ?", c.ClientID).
|
||||
Updates(map[string]interface{}{
|
||||
"disconnected_at": &now,
|
||||
})
|
||||
|
||||
if res.RowsAffected == 0 {
|
||||
|
||||
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
})
|
||||
}
|
||||
if res.Error != nil {
|
||||
|
||||
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
"error": res.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SafeClient() *Client {
|
||||
return &Client{
|
||||
ClientID: c.ClientID,
|
||||
APIKey: c.APIKey,
|
||||
IPAddress: c.IPAddress,
|
||||
UserAgent: c.UserAgent,
|
||||
Channels: c.Channels,
|
||||
LogLevels: c.LogLevels,
|
||||
Services: c.Services,
|
||||
Labels: c.Labels,
|
||||
ConnectedAt: c.ConnectedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllClients returns safe representations of all clients
|
||||
func GetAllClients() []*Client {
|
||||
clientsMu.RLock()
|
||||
defer clientsMu.RUnlock()
|
||||
|
||||
var clientList []*Client
|
||||
for client := range clients {
|
||||
clientList = append(clientList, client.SafeClient())
|
||||
}
|
||||
return clientList
|
||||
}
|
||||
|
||||
// GetClientsByChannel returns clients in a specific channel
|
||||
func GetClientsByChannel(channel string) []*Client {
|
||||
clientsMu.RLock()
|
||||
defer clientsMu.RUnlock()
|
||||
|
||||
var channelClients []*Client
|
||||
for client := range clients {
|
||||
if client.Channels[channel] {
|
||||
channelClients = append(channelClients, client.SafeClient())
|
||||
}
|
||||
}
|
||||
return channelClients
|
||||
}
|
||||
|
||||
// heat beat stuff
|
||||
const (
|
||||
pingPeriod = 30 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
writeWait = 10 * time.Second
|
||||
)
|
||||
|
||||
func (c *Client) StartHeartbeat() {
|
||||
logger := logger.New()
|
||||
log.Println("Started hearbeat")
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !c.isAlive.Load() { // Correct way to read atomic.Bool
|
||||
return
|
||||
}
|
||||
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("Heartbeat failed for %s: %v", c.ClientID, err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
res := db.DB.Model(&models.ClientRecord{}).
|
||||
Where("client_id = ?", c.ClientID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_heartbeat": &now,
|
||||
})
|
||||
|
||||
if res.RowsAffected == 0 {
|
||||
|
||||
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
})
|
||||
}
|
||||
if res.Error != nil {
|
||||
|
||||
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
"error": res.Error,
|
||||
})
|
||||
}
|
||||
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if c.isAlive.CompareAndSwap(true, false) { // Atomic swap
|
||||
close(c.done)
|
||||
c.Conn.Close()
|
||||
// Add any other cleanup here
|
||||
c.MarkDisconnected()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) startServerPings() {
|
||||
ticker := time.NewTicker(60 * time.Second) // Ping every 30s
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
c.Close() // Disconnect if ping fails
|
||||
return
|
||||
}
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) markActive() {
|
||||
c.lastActive = time.Now() // No mutex needed if single-writer
|
||||
}
|
||||
|
||||
func (c *Client) IsActive() bool {
|
||||
return time.Since(c.lastActive) < 45*time.Second // 1.5x ping interval
|
||||
}
|
||||
|
||||
func (c *Client) updateHeartbeat() {
|
||||
//fmt.Println("Updating heatbeat")
|
||||
now := time.Now()
|
||||
logger := logger.New()
|
||||
|
||||
//fmt.Printf("Updating heartbeat for client: %s at %v\n", c.ClientID, now)
|
||||
|
||||
//db.DB = db.DB.Debug()
|
||||
res := db.DB.Model(&models.ClientRecord{}).
|
||||
Where("client_id = ?", c.ClientID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_heartbeat": &now, // Explicit format
|
||||
})
|
||||
//fmt.Printf("Executed SQL: %v\n", db.DB.Statement.SQL.String())
|
||||
if res.RowsAffected == 0 {
|
||||
|
||||
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
})
|
||||
}
|
||||
if res.Error != nil {
|
||||
|
||||
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
||||
"clientID": c.ClientID,
|
||||
"error": res.Error,
|
||||
})
|
||||
}
|
||||
// 2. Verify DB connection
|
||||
if db.DB == nil {
|
||||
logger.Error("DB connection is nil", "websocket", map[string]interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Test raw SQL execution first
|
||||
testRes := db.DB.Exec("SELECT 1")
|
||||
if testRes.Error != nil {
|
||||
logger.Error("DB ping failed", "websocket", map[string]interface{}{
|
||||
"error": testRes.Error,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// work on this stats later
|
||||
// Add to your admin endpoint
|
||||
// type ConnectionStats struct {
|
||||
// TotalConnections int `json:"total_connections"`
|
||||
// ActiveConnections int `json:"active_connections"`
|
||||
// AvgDuration string `json:"avg_duration"`
|
||||
// }
|
||||
|
||||
// func GetConnectionStats() ConnectionStats {
|
||||
// // Implement your metrics tracking
|
||||
// }
|
||||
224
backend/internal/notifications/ws/ws_handler.go
Normal file
224
backend/internal/notifications/ws/ws_handler.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type JoinPayload struct {
|
||||
Channel string `json:"channel"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Services []string `json:"services,omitempty"`
|
||||
Levels []string `json:"levels,omitempty"`
|
||||
Labels []string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true }, // allow all origins; customize for prod
|
||||
HandshakeTimeout: 15 * time.Second,
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
EnableCompression: true,
|
||||
}
|
||||
|
||||
func SocketHandler(c *gin.Context, channels map[string]*Channel) {
|
||||
// Upgrade HTTP to WebSocket
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Println("WebSocket upgrade failed:", err)
|
||||
return
|
||||
}
|
||||
//defer conn.Close()
|
||||
|
||||
// Create new client
|
||||
client := &Client{
|
||||
Conn: conn,
|
||||
APIKey: "exampleAPIKey",
|
||||
Send: make(chan []byte, 256), // Buffered channel
|
||||
Channels: make(map[string]bool),
|
||||
IPAddress: c.ClientIP(),
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
client.isAlive.Store(true)
|
||||
// Add to global clients map
|
||||
clientsMu.Lock()
|
||||
clients[client] = true
|
||||
clientsMu.Unlock()
|
||||
|
||||
// Save initial connection to DB
|
||||
client.SaveToDB()
|
||||
// Save initial connection to DB
|
||||
// if err := client.SaveToDB(); err != nil {
|
||||
// log.Println("Failed to save client to DB:", err)
|
||||
// conn.Close()
|
||||
// return
|
||||
// }
|
||||
|
||||
// Set handlers
|
||||
conn.SetPingHandler(func(string) error {
|
||||
return nil // Auto-responds with pong
|
||||
})
|
||||
|
||||
conn.SetPongHandler(func(string) error {
|
||||
now := time.Now()
|
||||
client.markActive() // Track last pong time
|
||||
client.lastActive = now
|
||||
client.updateHeartbeat()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start server-side ping ticker
|
||||
go client.startServerPings()
|
||||
|
||||
defer func() {
|
||||
// Unregister from all channels
|
||||
for channelName := range client.Channels {
|
||||
if ch, exists := channels[channelName]; exists {
|
||||
ch.Unregister <- client
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from global clients map
|
||||
clientsMu.Lock()
|
||||
delete(clients, client)
|
||||
clientsMu.Unlock()
|
||||
|
||||
// Mark disconnected in DB
|
||||
client.MarkDisconnected()
|
||||
|
||||
// Close connection
|
||||
conn.Close()
|
||||
log.Printf("Client disconnected: %s", client.ClientID)
|
||||
}()
|
||||
|
||||
// Send welcome message immediately
|
||||
welcomeMsg := map[string]string{
|
||||
"status": "connected",
|
||||
"message": "Welcome to the WebSocket server. Send subscription request to begin.",
|
||||
}
|
||||
if err := conn.WriteJSON(welcomeMsg); err != nil {
|
||||
log.Println("Failed to send welcome message:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Message handling goroutine
|
||||
go func() {
|
||||
defer func() {
|
||||
// Cleanup on disconnect
|
||||
for channelName := range client.Channels {
|
||||
if ch, exists := channels[channelName]; exists {
|
||||
ch.Unregister <- client
|
||||
}
|
||||
}
|
||||
close(client.Send)
|
||||
client.MarkDisconnected()
|
||||
}()
|
||||
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
||||
log.Printf("Client disconnected unexpectedly: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Channel string `json:"channel"`
|
||||
APIKey string `json:"apiKey"`
|
||||
Services []string `json:"services,omitempty"`
|
||||
Levels []string `json:"levels,omitempty"`
|
||||
Labels []string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(msg, &payload); err != nil {
|
||||
conn.WriteJSON(map[string]string{"error": "invalid payload format"})
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate API key (implement your own validateAPIKey function)
|
||||
// if payload.APIKey == "" || !validateAPIKey(payload.APIKey) {
|
||||
// conn.WriteJSON(map[string]string{"error": "invalid or missing API key"})
|
||||
// continue
|
||||
// }
|
||||
|
||||
if payload.APIKey == "" {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte("Missing API Key"))
|
||||
continue
|
||||
}
|
||||
client.APIKey = payload.APIKey
|
||||
|
||||
// Handle channel subscription
|
||||
switch payload.Channel {
|
||||
case "logServices":
|
||||
// Unregister from other channels if needed
|
||||
if client.Channels["labels"] {
|
||||
channels["labels"].Unregister <- client
|
||||
delete(client.Channels, "labels")
|
||||
}
|
||||
|
||||
// Update client filters
|
||||
client.Services = payload.Services
|
||||
client.LogLevels = payload.Levels
|
||||
|
||||
// Register to channel
|
||||
channels["logServices"].Register <- client
|
||||
client.Channels["logServices"] = true
|
||||
|
||||
conn.WriteJSON(map[string]string{
|
||||
"status": "subscribed",
|
||||
"channel": "logServices",
|
||||
})
|
||||
|
||||
case "labels":
|
||||
// Unregister from other channels if needed
|
||||
if client.Channels["logServices"] {
|
||||
channels["logServices"].Unregister <- client
|
||||
delete(client.Channels, "logServices")
|
||||
}
|
||||
|
||||
// Set label filters if provided
|
||||
if payload.Labels != nil {
|
||||
client.Labels = payload.Labels
|
||||
}
|
||||
|
||||
// Register to channel
|
||||
channels["labels"].Register <- client
|
||||
client.Channels["labels"] = true
|
||||
|
||||
// Update DB record
|
||||
client.SaveToDB()
|
||||
// if err := client.SaveToDB(); err != nil {
|
||||
// log.Println("Failed to update client labels:", err)
|
||||
// }
|
||||
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"status": "subscribed",
|
||||
"channel": "labels",
|
||||
"filters": client.Labels,
|
||||
})
|
||||
|
||||
default:
|
||||
conn.WriteJSON(map[string]string{
|
||||
"error": "invalid channel",
|
||||
"available_channels": "logServices, labels",
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Send messages to client
|
||||
for message := range client.Send {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
log.Println("Write error:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
80
backend/internal/notifications/ws/ws_log_service.go
Normal file
80
backend/internal/notifications/ws/ws_log_service.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package ws
|
||||
|
||||
// setup the notifiyer
|
||||
|
||||
// -- Only needs to be run once in DB
|
||||
// CREATE OR REPLACE FUNCTION notify_new_log() RETURNS trigger AS $$
|
||||
// BEGIN
|
||||
// PERFORM pg_notify('new_log', row_to_json(NEW)::text);
|
||||
// RETURN NEW;
|
||||
// END;
|
||||
// $$ LANGUAGE plpgsql;
|
||||
|
||||
// DROP TRIGGER IF EXISTS new_log_trigger ON logs;
|
||||
|
||||
// CREATE TRIGGER new_log_trigger
|
||||
// AFTER INSERT ON logs
|
||||
// FOR EACH ROW EXECUTE FUNCTION notify_new_log();
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"lst.net/pkg/logger"
|
||||
)
|
||||
|
||||
func LogServices(broadcaster chan logger.Message) {
|
||||
log := logger.New()
|
||||
|
||||
log.Info("[LogServices] started - single channel for all logs", "websocket", map[string]interface{}{})
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
os.Getenv("DB_HOST"),
|
||||
os.Getenv("DB_PORT"),
|
||||
os.Getenv("DB_USER"),
|
||||
os.Getenv("DB_PASSWORD"),
|
||||
os.Getenv("DB_NAME"),
|
||||
)
|
||||
|
||||
listener := pq.NewListener(dsn, 10*time.Second, time.Minute, nil)
|
||||
err := listener.Listen("new_log")
|
||||
if err != nil {
|
||||
log.Panic("Failed to LISTEN on new_log", "logger", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("Listening for all logs through single logServices channel...")
|
||||
for {
|
||||
select {
|
||||
case notify := <-listener.Notify:
|
||||
if notify != nil {
|
||||
var logData map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(notify.Extra), &logData); err != nil {
|
||||
log.Error("Failed to unmarshal notification payload", "logger", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Always send to logServices channel
|
||||
broadcaster <- logger.Message{
|
||||
Channel: "logServices",
|
||||
Data: logData,
|
||||
Meta: map[string]interface{}{
|
||||
"level": logData["level"],
|
||||
"service": logData["service"],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
case <-time.After(90 * time.Second):
|
||||
go func() {
|
||||
listener.Ping()
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
55
backend/internal/notifications/ws/ws_routes.go
Normal file
55
backend/internal/notifications/ws/ws_routes.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"lst.net/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
broadcaster = make(chan logger.Message)
|
||||
)
|
||||
|
||||
func RegisterSocketRoutes(r *gin.Engine, base_url string) {
|
||||
// Initialize all channels
|
||||
InitializeChannels()
|
||||
|
||||
// Start channel processors
|
||||
StartAllChannels()
|
||||
|
||||
// Start background services
|
||||
go LogServices(broadcaster)
|
||||
go StartBroadcasting(broadcaster, channels)
|
||||
|
||||
// WebSocket route
|
||||
r.GET(base_url+"/ws", func(c *gin.Context) {
|
||||
SocketHandler(c, channels)
|
||||
})
|
||||
|
||||
r.GET(base_url+"/ws/clients", AdminAuthMiddleware(), handleGetClients)
|
||||
}
|
||||
|
||||
func handleGetClients(c *gin.Context) {
|
||||
channel := c.Query("channel")
|
||||
|
||||
var clientList []*Client
|
||||
if channel != "" {
|
||||
clientList = GetClientsByChannel(channel)
|
||||
} else {
|
||||
clientList = GetAllClients()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"count": len(clientList),
|
||||
"clients": clientList,
|
||||
})
|
||||
}
|
||||
|
||||
func AdminAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Implement your admin authentication logic
|
||||
// Example: Check API key or JWT token
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user