274 lines
6.5 KiB
Go
274 lines
6.5 KiB
Go
package websocket
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"lst.net/utils/db"
|
|
logging "lst.net/utils/logger"
|
|
)
|
|
|
|
var (
|
|
clients = make(map[*Client]bool)
|
|
clientsMu sync.RWMutex
|
|
)
|
|
|
|
type Client struct {
|
|
ClientID uuid.UUID `json:"client_id"`
|
|
Conn *websocket.Conn `json:"-"` // Excluded from JSON
|
|
APIKey string `json:"api_key"`
|
|
IPAddress string `json:"ip_address"`
|
|
UserAgent string `json:"user_agent"`
|
|
Send chan []byte `json:"-"` // Excluded from JSON
|
|
Channels map[string]bool `json:"channels"`
|
|
LogLevels []string `json:"levels,omitempty"`
|
|
Services []string `json:"services,omitempty"`
|
|
Labels []string `json:"labels,omitempty"`
|
|
ConnectedAt time.Time `json:"connected_at"`
|
|
done chan struct{} // For graceful shutdown
|
|
isAlive atomic.Bool
|
|
lastActive time.Time // Tracks last activity
|
|
|
|
}
|
|
|
|
func (c *Client) SaveToDB() {
|
|
// Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB
|
|
channels := make(map[string]interface{})
|
|
for ch := range c.Channels {
|
|
channels[ch] = true
|
|
}
|
|
|
|
clientRecord := &db.ClientRecord{
|
|
APIKey: c.APIKey,
|
|
IPAddress: c.IPAddress,
|
|
UserAgent: c.UserAgent,
|
|
Channels: db.JSONB(channels),
|
|
ConnectedAt: time.Now(),
|
|
LastHeartbeat: time.Now(),
|
|
}
|
|
|
|
if err := db.DB.Create(&clientRecord).Error; err != nil {
|
|
log.Println("❌ Error saving client:", err)
|
|
|
|
} else {
|
|
c.ClientID = clientRecord.ClientID
|
|
c.ConnectedAt = clientRecord.ConnectedAt
|
|
}
|
|
}
|
|
|
|
func (c *Client) MarkDisconnected() {
|
|
logger := logging.New()
|
|
clientData := fmt.Sprintf("Client %v just lefts us", c.ClientID)
|
|
logger.Info(clientData, "websocket", map[string]interface{}{})
|
|
|
|
now := time.Now()
|
|
res := db.DB.Model(&db.ClientRecord{}).
|
|
Where("client_id = ?", c.ClientID).
|
|
Updates(map[string]interface{}{
|
|
"disconnected_at": &now,
|
|
})
|
|
|
|
if res.RowsAffected == 0 {
|
|
|
|
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
})
|
|
}
|
|
if res.Error != nil {
|
|
|
|
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
"error": res.Error,
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Client) SafeClient() *Client {
|
|
return &Client{
|
|
ClientID: c.ClientID,
|
|
APIKey: c.APIKey,
|
|
IPAddress: c.IPAddress,
|
|
UserAgent: c.UserAgent,
|
|
Channels: c.Channels,
|
|
LogLevels: c.LogLevels,
|
|
Services: c.Services,
|
|
Labels: c.Labels,
|
|
ConnectedAt: c.ConnectedAt,
|
|
}
|
|
}
|
|
|
|
// GetAllClients returns safe representations of all clients
|
|
func GetAllClients() []*Client {
|
|
clientsMu.RLock()
|
|
defer clientsMu.RUnlock()
|
|
|
|
var clientList []*Client
|
|
for client := range clients {
|
|
clientList = append(clientList, client.SafeClient())
|
|
}
|
|
return clientList
|
|
}
|
|
|
|
// GetClientsByChannel returns clients in a specific channel
|
|
func GetClientsByChannel(channel string) []*Client {
|
|
clientsMu.RLock()
|
|
defer clientsMu.RUnlock()
|
|
|
|
var channelClients []*Client
|
|
for client := range clients {
|
|
if client.Channels[channel] {
|
|
channelClients = append(channelClients, client.SafeClient())
|
|
}
|
|
}
|
|
return channelClients
|
|
}
|
|
|
|
// Heartbeat timing parameters.
const (
	// pingPeriod is the interval between pings sent by StartHeartbeat.
	pingPeriod time.Duration = 30 * time.Second

	// pongWait is a 60-second window.
	// NOTE(review): not referenced in the code visible here; presumably meant
	// as the read deadline for pong replies — confirm.
	pongWait time.Duration = 60 * time.Second

	// writeWait bounds how long StartHeartbeat lets a single ping write take.
	writeWait time.Duration = 10 * time.Second
)
|
|
|
|
func (c *Client) StartHeartbeat() {
|
|
logger := logging.New()
|
|
log.Println("Started hearbeat")
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if !c.isAlive.Load() { // Correct way to read atomic.Bool
|
|
return
|
|
}
|
|
|
|
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
log.Printf("Heartbeat failed for %s: %v", c.ClientID, err)
|
|
c.Close()
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
res := db.DB.Model(&db.ClientRecord{}).
|
|
Where("client_id = ?", c.ClientID).
|
|
Updates(map[string]interface{}{
|
|
"last_heartbeat": &now,
|
|
})
|
|
|
|
if res.RowsAffected == 0 {
|
|
|
|
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
})
|
|
}
|
|
if res.Error != nil {
|
|
|
|
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
"error": res.Error,
|
|
})
|
|
}
|
|
|
|
case <-c.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) Close() {
|
|
if c.isAlive.CompareAndSwap(true, false) { // Atomic swap
|
|
close(c.done)
|
|
c.Conn.Close()
|
|
// Add any other cleanup here
|
|
c.MarkDisconnected()
|
|
}
|
|
}
|
|
|
|
func (c *Client) startServerPings() {
|
|
ticker := time.NewTicker(60 * time.Second) // Ping every 30s
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
c.Close() // Disconnect if ping fails
|
|
return
|
|
}
|
|
case <-c.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) markActive() {
|
|
c.lastActive = time.Now() // No mutex needed if single-writer
|
|
}
|
|
|
|
func (c *Client) IsActive() bool {
|
|
return time.Since(c.lastActive) < 45*time.Second // 1.5x ping interval
|
|
}
|
|
|
|
func (c *Client) updateHeartbeat() {
|
|
//fmt.Println("Updating heatbeat")
|
|
now := time.Now()
|
|
logger := logging.New()
|
|
|
|
//fmt.Printf("Updating heartbeat for client: %s at %v\n", c.ClientID, now)
|
|
|
|
//db.DB = db.DB.Debug()
|
|
res := db.DB.Model(&db.ClientRecord{}).
|
|
Where("client_id = ?", c.ClientID).
|
|
Updates(map[string]interface{}{
|
|
"last_heartbeat": &now, // Explicit format
|
|
})
|
|
//fmt.Printf("Executed SQL: %v\n", db.DB.Statement.SQL.String())
|
|
if res.RowsAffected == 0 {
|
|
|
|
logger.Info("⚠️ No rows updated for client_id", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
})
|
|
}
|
|
if res.Error != nil {
|
|
|
|
logger.Error("❌ Error updating disconnected_at", "websocket", map[string]interface{}{
|
|
"clientID": c.ClientID,
|
|
"error": res.Error,
|
|
})
|
|
}
|
|
// 2. Verify DB connection
|
|
if db.DB == nil {
|
|
logger.Error("DB connection is nil", "websocket", map[string]interface{}{})
|
|
return
|
|
}
|
|
|
|
// 3. Test raw SQL execution first
|
|
testRes := db.DB.Exec("SELECT 1")
|
|
if testRes.Error != nil {
|
|
logger.Error("DB ping failed", "websocket", map[string]interface{}{
|
|
"error": testRes.Error,
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
// TODO: implement these connection stats later and
// expose them from the admin endpoint.
|
|
// type ConnectionStats struct {
|
|
// TotalConnections int `json:"total_connections"`
|
|
// ActiveConnections int `json:"active_connections"`
|
|
// AvgDuration string `json:"avg_duration"`
|
|
// }
|
|
|
|
// func GetConnectionStats() ConnectionStats {
|
|
// // Implement your metrics tracking
|
|
// }
|