Files
logistics_support_tool/backend/cmd/services/websocket/handler.go

164 lines
3.7 KiB
Go

package socketio
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"lst.net/utils/db"
)
// JoinPayload is the JSON message a websocket client sends to subscribe to a
// channel. It is decoded from every text frame read by SocketHandler.
type JoinPayload struct {
	// Channel names the subscription target ("channel" in JSON).
	Channel string `json:"channel"`

	// Services optionally lists service names; SocketHandler reads it for the
	// "logServices" channel. Omitted from encoded JSON when empty.
	Services []string `json:"services,omitempty"`

	// APIKey is the caller's key ("apiKey" in JSON); SocketHandler rejects
	// requests where it is empty.
	APIKey string `json:"apiKey"`
}
// type Channel struct {
// Name string
// Clients map[*Client]bool
// Register chan *Client
// Unregister chan *Client
// Broadcast chan Message
// }
// upgrader performs the HTTP -> websocket handshake for SocketHandler.
//
// NOTE(review): CheckOrigin accepts every request regardless of its Origin
// header. That is convenient during development but should be replaced with
// an allow-list before this is exposed in production.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(_ *http.Request) bool {
		return true
	},
}
// SocketHandler upgrades the incoming HTTP request to a websocket connection,
// registers the connection in the package-level clients map, and then
// processes JSON join requests (JoinPayload) until the connection fails.
//
// For each accepted request the client is subscribed to the requested
// channel, persisted with SaveToDB, and sent a {"status":"subscribed"}
// confirmation. Requests that are malformed, lack an API key, or name an
// unknown channel receive a plain-text error and the connection stays open.
//
// Teardown (removal from clients, MarkDisconnected, Disconnect, conn.Close)
// runs exactly once, from the deferred cleanup, whichever way the loop exits.
func SocketHandler(c *gin.Context) {
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// only stop the remaining handlers instead of writing a second status.
		log.Println("Failed to upgrade:", err)
		c.Abort()
		return
	}

	client := &Client{
		Conn:      conn,
		IPAddress: c.ClientIP(),
		UserAgent: c.Request.UserAgent(),
		Channels:  make(map[string]bool),
	}

	// Single teardown path: every exit from the read loop below goes through
	// here, so no branch needs to repeat the cleanup.
	defer func() {
		clientsLock.Lock()
		delete(clients, client)
		clientsLock.Unlock()
		client.MarkDisconnected()
		client.Disconnect()
		conn.Close()
	}()

	clientsLock.Lock()
	clients[client] = true
	clientsLock.Unlock()

	// send writes one text frame and reports whether the connection is still
	// usable; a failed write means the peer is gone and the loop should stop.
	send := func(data []byte) bool {
		if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
			log.Println("Write error:", err)
			return false
		}
		return true
	}

	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			log.Println("Read error:", err)
			return
		}

		var payload JoinPayload
		if err := json.Unmarshal(msg, &payload); err != nil {
			// A malformed frame is a client mistake, not a dead connection:
			// report it and keep the client registered, like the other
			// rejection branches below.
			log.Println("Invalid JSON payload:", err)
			if !send([]byte("Invalid JSON payload")) {
				return
			}
			continue
		}

		// Simple API key check (replace with real auth).
		if payload.APIKey == "" {
			if !send([]byte("Missing API Key")) {
				return
			}
			continue
		}
		client.APIKey = payload.APIKey

		// Handle channel subscription; add more cases here as channels are
		// introduced.
		switch payload.Channel {
		case "logs":
			client.Channels["logs"] = true
		case "logServices":
			for _, svc := range payload.Services {
				client.Channels["logServices:"+svc] = true
			}
		case "labels":
			client.Channels["labels"] = true
		default:
			if !send([]byte("Unknown channel")) {
				return
			}
			continue
		}

		// NOTE(review): SaveToDB runs on every accepted join request, so a
		// client that subscribes more than once appears to get one record per
		// request — confirm whether that is intended.
		client.SaveToDB()

		// Confirm the subscription to the client.
		respJSON, err := json.Marshal(map[string]string{
			"status":  "subscribed",
			"channel": payload.Channel,
		})
		if err != nil {
			log.Println("Failed to encode subscription response:", err)
			continue
		}
		if !send(respJSON) {
			return
		}
	}
}
// SaveToDB inserts a db.ClientRecord describing this client (API key, IP
// address, user agent and subscribed channels) and, on success, copies the
// record's ClientID back onto the client. ConnectedAt and LastHeartbeat are
// both set to the same instant. Insert failures are logged, not returned.
func (c *Client) SaveToDB() {
	// db.JSONB is built from a map[string]interface{}, so the channel set is
	// copied out of the map[string]bool held on the client.
	channels := make(map[string]interface{}, len(c.Channels))
	for ch := range c.Channels {
		channels[ch] = true
	}

	// Take the timestamp once so both time columns agree exactly.
	now := time.Now()
	record := &db.ClientRecord{
		APIKey:        c.APIKey,
		IPAddress:     c.IPAddress,
		UserAgent:     c.UserAgent,
		Channels:      db.JSONB(channels),
		ConnectedAt:   now,
		LastHeartbeat: now,
	}

	// record is already a pointer; pass it directly rather than its address.
	if err := db.DB.Create(record).Error; err != nil {
		log.Println("❌ Error saving client:", err)
		return
	}

	// Presumably ClientID is generated during the insert (the original code
	// described it as a generated UUID) — confirm against db.ClientRecord.
	c.ClientID = record.ClientID
}
// MarkDisconnected sets disconnected_at to the current time on the
// db.ClientRecord rows whose client_id equals c.ClientID. Nothing is
// returned: a query error is logged on its own, and a successful update that
// matched no rows is logged as a warning.
func (c *Client) MarkDisconnected() {
	now := time.Now()
	res := db.DB.Model(&db.ClientRecord{}).
		Where("client_id = ?", c.ClientID).
		Updates(map[string]interface{}{
			"disconnected_at": &now,
		})

	// Check the error first: a failed query also reports zero affected rows,
	// and the "no rows" warning would then be misleading.
	if res.Error != nil {
		log.Println("❌ Error updating disconnected_at:", res.Error)
		return
	}
	if res.RowsAffected == 0 {
		log.Println("⚠️ No rows updated for client_id:", c.ClientID)
	}
}