refactor(ws): ws logging and channel manager added no auth currently
This commit is contained in:
168
backend/cmd/services/websocket/ws_client.go
Normal file
168
backend/cmd/services/websocket/ws_client.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"lst.net/utils/db"
|
||||
)
|
||||
|
||||
var (
|
||||
clients = make(map[*Client]bool)
|
||||
clientsMu sync.RWMutex
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
ClientID uuid.UUID `json:"client_id"`
|
||||
Conn *websocket.Conn `json:"-"` // Excluded from JSON
|
||||
APIKey string `json:"api_key"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Send chan []byte `json:"-"` // Excluded from JSON
|
||||
Channels map[string]bool `json:"channels"`
|
||||
LogLevels []string `json:"levels,omitempty"`
|
||||
Services []string `json:"services,omitempty"`
|
||||
Labels []string `json:"labels,omitempty"`
|
||||
ConnectedAt time.Time `json:"connected_at"`
|
||||
done chan struct{} // For graceful shutdown
|
||||
isAlive atomic.Bool
|
||||
//mu sync.Mutex // Protects isAlive if not using atomic
|
||||
}
|
||||
|
||||
func (c *Client) SaveToDB() {
|
||||
// Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB
|
||||
channels := make(map[string]interface{})
|
||||
for ch := range c.Channels {
|
||||
channels[ch] = true
|
||||
}
|
||||
|
||||
clientRecord := &db.ClientRecord{
|
||||
APIKey: c.APIKey,
|
||||
IPAddress: c.IPAddress,
|
||||
UserAgent: c.UserAgent,
|
||||
Channels: db.JSONB(channels),
|
||||
ConnectedAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
}
|
||||
|
||||
if err := db.DB.Create(&clientRecord).Error; err != nil {
|
||||
log.Println("❌ Error saving client:", err)
|
||||
|
||||
} else {
|
||||
c.ClientID = clientRecord.ClientID
|
||||
c.ConnectedAt = clientRecord.ConnectedAt
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) MarkDisconnected() {
|
||||
log.Printf("Client %v just lefts us", c.ClientID)
|
||||
now := time.Now()
|
||||
res := db.DB.Model(&db.ClientRecord{}).
|
||||
Where("client_id = ?", c.ClientID).
|
||||
Updates(map[string]interface{}{
|
||||
"disconnected_at": &now,
|
||||
})
|
||||
|
||||
if res.RowsAffected == 0 {
|
||||
log.Println("⚠️ No rows updated for client_id:", c.ClientID)
|
||||
}
|
||||
if res.Error != nil {
|
||||
log.Println("❌ Error updating disconnected_at:", res.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SafeClient() *Client {
|
||||
return &Client{
|
||||
ClientID: c.ClientID,
|
||||
APIKey: c.APIKey,
|
||||
IPAddress: c.IPAddress,
|
||||
UserAgent: c.UserAgent,
|
||||
Channels: c.Channels,
|
||||
LogLevels: c.LogLevels,
|
||||
Services: c.Services,
|
||||
Labels: c.Labels,
|
||||
ConnectedAt: c.ConnectedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllClients returns safe representations of all clients
|
||||
func GetAllClients() []*Client {
|
||||
clientsMu.RLock()
|
||||
defer clientsMu.RUnlock()
|
||||
|
||||
var clientList []*Client
|
||||
for client := range clients {
|
||||
clientList = append(clientList, client.SafeClient())
|
||||
}
|
||||
return clientList
|
||||
}
|
||||
|
||||
// GetClientsByChannel returns clients in a specific channel
|
||||
func GetClientsByChannel(channel string) []*Client {
|
||||
clientsMu.RLock()
|
||||
defer clientsMu.RUnlock()
|
||||
|
||||
var channelClients []*Client
|
||||
for client := range clients {
|
||||
if client.Channels[channel] {
|
||||
channelClients = append(channelClients, client.SafeClient())
|
||||
}
|
||||
}
|
||||
return channelClients
|
||||
}
|
||||
|
||||
// heat beat stuff
|
||||
const (
|
||||
pingPeriod = 30 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
writeWait = 10 * time.Second
|
||||
)
|
||||
|
||||
func (c *Client) StartHeartbeat() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !c.isAlive.Load() { // Correct way to read atomic.Bool
|
||||
return
|
||||
}
|
||||
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("Heartbeat failed for %s: %v", c.ClientID, err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if c.isAlive.CompareAndSwap(true, false) { // Atomic swap
|
||||
close(c.done)
|
||||
c.Conn.Close()
|
||||
// Add any other cleanup here
|
||||
c.MarkDisconnected()
|
||||
}
|
||||
}
|
||||
|
||||
// work on this stats later
|
||||
// Add to your admin endpoint
|
||||
// type ConnectionStats struct {
|
||||
// TotalConnections int `json:"total_connections"`
|
||||
// ActiveConnections int `json:"active_connections"`
|
||||
// AvgDuration string `json:"avg_duration"`
|
||||
// }
|
||||
|
||||
// func GetConnectionStats() ConnectionStats {
|
||||
// // Implement your metrics tracking
|
||||
// }
|
||||
Reference in New Issue
Block a user