164 lines
3.7 KiB
Go
164 lines
3.7 KiB
Go
package socketio
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
"lst.net/utils/db"
|
|
)
|
|
|
|
// JoinPayload is the JSON message a WebSocket client sends to subscribe to a
// channel, for example:
//
//	{"channel": "logServices", "services": ["api", "worker"], "apiKey": "..."}
type JoinPayload struct {
	// Channel names the stream to subscribe to; SocketHandler accepts
	// "logs", "logServices" and "labels".
	Channel string `json:"channel"`
	// Services lists service names; SocketHandler only reads it for the
	// "logServices" channel. Omitted from the encoding when empty.
	Services []string `json:"services,omitempty"`
	// APIKey must be non-empty for SocketHandler to accept the request.
	APIKey string `json:"apiKey"`
}
|
|
|
|
// Channel is currently unused; the definition is kept commented out for
// reference.
//
// type Channel struct {
// 	Name       string
// 	Clients    map[*Client]bool
// 	Register   chan *Client
// 	Unregister chan *Client
// 	Broadcast  chan Message
// }
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true }, // allow all origins; customize for prod
|
|
}
|
|
|
|
func SocketHandler(c *gin.Context) {
|
|
// Upgrade HTTP to websocket
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
log.Println("Failed to upgrade:", err)
|
|
c.AbortWithStatus(http.StatusBadRequest)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Create client struct
|
|
client := &Client{
|
|
Conn: conn,
|
|
IPAddress: c.ClientIP(),
|
|
UserAgent: c.Request.UserAgent(),
|
|
Channels: make(map[string]bool),
|
|
}
|
|
|
|
clientsLock.Lock()
|
|
clients[client] = true
|
|
clientsLock.Unlock()
|
|
|
|
defer func() {
|
|
clientsLock.Lock()
|
|
delete(clients, client)
|
|
clientsLock.Unlock()
|
|
client.MarkDisconnected()
|
|
client.Disconnect()
|
|
conn.Close()
|
|
}()
|
|
|
|
for {
|
|
// Read message from client
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
log.Println("Read error:", err)
|
|
clientsLock.Lock()
|
|
delete(clients, client)
|
|
clientsLock.Unlock()
|
|
client.MarkDisconnected()
|
|
client.Disconnect()
|
|
break
|
|
}
|
|
|
|
var payload JoinPayload
|
|
err = json.Unmarshal(msg, &payload)
|
|
if err != nil {
|
|
log.Println("Invalid JSON payload:", err)
|
|
clientsLock.Lock()
|
|
delete(clients, client)
|
|
clientsLock.Unlock()
|
|
client.MarkDisconnected()
|
|
client.Disconnect()
|
|
continue
|
|
}
|
|
|
|
// Simple API key check (replace with real auth)
|
|
if payload.APIKey == "" {
|
|
conn.WriteMessage(websocket.TextMessage, []byte("Missing API Key"))
|
|
continue
|
|
}
|
|
client.APIKey = payload.APIKey
|
|
|
|
// Handle channel subscription, add more here as we get more in.
|
|
switch payload.Channel {
|
|
case "logs":
|
|
client.Channels["logs"] = true
|
|
case "logServices":
|
|
for _, svc := range payload.Services {
|
|
client.Channels["logServices:"+svc] = true
|
|
}
|
|
case "labels":
|
|
client.Channels["labels"] = true
|
|
default:
|
|
conn.WriteMessage(websocket.TextMessage, []byte("Unknown channel"))
|
|
continue
|
|
}
|
|
|
|
// Save client info in DB
|
|
client.SaveToDB()
|
|
|
|
// Confirm subscription
|
|
resp := map[string]string{
|
|
"status": "subscribed",
|
|
"channel": payload.Channel,
|
|
}
|
|
respJSON, _ := json.Marshal(resp)
|
|
conn.WriteMessage(websocket.TextMessage, respJSON)
|
|
|
|
// You could now start pushing messages to client or keep connection open
|
|
// For demo, just wait and keep connection alive
|
|
}
|
|
}
|
|
|
|
func (c *Client) SaveToDB() {
|
|
// Convert c.Channels (map[string]bool) to map[string]interface{} for JSONB
|
|
channels := make(map[string]interface{})
|
|
for ch := range c.Channels {
|
|
channels[ch] = true
|
|
}
|
|
|
|
clientRecord := &db.ClientRecord{
|
|
APIKey: c.APIKey,
|
|
IPAddress: c.IPAddress,
|
|
UserAgent: c.UserAgent,
|
|
Channels: db.JSONB(channels),
|
|
ConnectedAt: time.Now(),
|
|
LastHeartbeat: time.Now(),
|
|
}
|
|
|
|
if err := db.DB.Create(&clientRecord).Error; err != nil {
|
|
log.Println("❌ Error saving client:", err)
|
|
} else {
|
|
c.ClientID = clientRecord.ClientID // ✅ Assign the generated UUID back to the client
|
|
}
|
|
}
|
|
|
|
func (c *Client) MarkDisconnected() {
|
|
now := time.Now()
|
|
res := db.DB.Model(&db.ClientRecord{}).
|
|
Where("client_id = ?", c.ClientID).
|
|
Updates(map[string]interface{}{
|
|
"disconnected_at": &now,
|
|
})
|
|
|
|
if res.RowsAffected == 0 {
|
|
log.Println("⚠️ No rows updated for client_id:", c.ClientID)
|
|
}
|
|
if res.Error != nil {
|
|
log.Println("❌ Error updating disconnected_at:", res.Error)
|
|
}
|
|
}
|