225 lines
5.5 KiB
Go
225 lines
5.5 KiB
Go
package ws
|
|
|
|
import (
	"encoding/json"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)
|
|
|
|
// JoinPayload describes a client's subscription request: the same set of
// JSON fields SocketHandler decodes from each inbound WebSocket message.
//
// Channel names the subscription SocketHandler should use ("logServices" or
// "labels"); SocketHandler rejects a request whose apiKey is empty. The
// slice fields are optional filters and, because of omitempty, are left out
// of encoded output when they are empty.
type JoinPayload struct {
	Channel  string   `json:"channel"`
	APIKey   string   `json:"apiKey"`
	Services []string `json:"services,omitempty"` // copied to client.Services for "logServices"
	Levels   []string `json:"levels,omitempty"`   // copied to client.LogLevels for "logServices"
	Labels   []string `json:"labels,omitempty"`   // copied to client.Labels for "labels" when non-nil
}
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true }, // allow all origins; customize for prod
|
|
HandshakeTimeout: 15 * time.Second,
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
EnableCompression: true,
|
|
}
|
|
|
|
func SocketHandler(c *gin.Context, channels map[string]*Channel) {
|
|
// Upgrade HTTP to WebSocket
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
log.Println("WebSocket upgrade failed:", err)
|
|
return
|
|
}
|
|
//defer conn.Close()
|
|
|
|
// Create new client
|
|
client := &Client{
|
|
Conn: conn,
|
|
APIKey: "exampleAPIKey",
|
|
Send: make(chan []byte, 256), // Buffered channel
|
|
Channels: make(map[string]bool),
|
|
IPAddress: c.ClientIP(),
|
|
UserAgent: c.Request.UserAgent(),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
client.isAlive.Store(true)
|
|
// Add to global clients map
|
|
clientsMu.Lock()
|
|
clients[client] = true
|
|
clientsMu.Unlock()
|
|
|
|
// Save initial connection to DB
|
|
client.SaveToDB()
|
|
// Save initial connection to DB
|
|
// if err := client.SaveToDB(); err != nil {
|
|
// log.Println("Failed to save client to DB:", err)
|
|
// conn.Close()
|
|
// return
|
|
// }
|
|
|
|
// Set handlers
|
|
conn.SetPingHandler(func(string) error {
|
|
return nil // Auto-responds with pong
|
|
})
|
|
|
|
conn.SetPongHandler(func(string) error {
|
|
now := time.Now()
|
|
client.markActive() // Track last pong time
|
|
client.lastActive = now
|
|
client.updateHeartbeat()
|
|
return nil
|
|
})
|
|
|
|
// Start server-side ping ticker
|
|
go client.startServerPings()
|
|
|
|
defer func() {
|
|
// Unregister from all channels
|
|
for channelName := range client.Channels {
|
|
if ch, exists := channels[channelName]; exists {
|
|
ch.Unregister <- client
|
|
}
|
|
}
|
|
|
|
// Remove from global clients map
|
|
clientsMu.Lock()
|
|
delete(clients, client)
|
|
clientsMu.Unlock()
|
|
|
|
// Mark disconnected in DB
|
|
client.MarkDisconnected()
|
|
|
|
// Close connection
|
|
conn.Close()
|
|
log.Printf("Client disconnected: %s", client.ClientID)
|
|
}()
|
|
|
|
// Send welcome message immediately
|
|
welcomeMsg := map[string]string{
|
|
"status": "connected",
|
|
"message": "Welcome to the WebSocket server. Send subscription request to begin.",
|
|
}
|
|
if err := conn.WriteJSON(welcomeMsg); err != nil {
|
|
log.Println("Failed to send welcome message:", err)
|
|
return
|
|
}
|
|
|
|
// Message handling goroutine
|
|
go func() {
|
|
defer func() {
|
|
// Cleanup on disconnect
|
|
for channelName := range client.Channels {
|
|
if ch, exists := channels[channelName]; exists {
|
|
ch.Unregister <- client
|
|
}
|
|
}
|
|
close(client.Send)
|
|
client.MarkDisconnected()
|
|
}()
|
|
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
|
|
log.Printf("Client disconnected unexpectedly: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
var payload struct {
|
|
Channel string `json:"channel"`
|
|
APIKey string `json:"apiKey"`
|
|
Services []string `json:"services,omitempty"`
|
|
Levels []string `json:"levels,omitempty"`
|
|
Labels []string `json:"labels,omitempty"`
|
|
}
|
|
|
|
if err := json.Unmarshal(msg, &payload); err != nil {
|
|
conn.WriteJSON(map[string]string{"error": "invalid payload format"})
|
|
continue
|
|
}
|
|
|
|
// Validate API key (implement your own validateAPIKey function)
|
|
// if payload.APIKey == "" || !validateAPIKey(payload.APIKey) {
|
|
// conn.WriteJSON(map[string]string{"error": "invalid or missing API key"})
|
|
// continue
|
|
// }
|
|
|
|
if payload.APIKey == "" {
|
|
conn.WriteMessage(websocket.TextMessage, []byte("Missing API Key"))
|
|
continue
|
|
}
|
|
client.APIKey = payload.APIKey
|
|
|
|
// Handle channel subscription
|
|
switch payload.Channel {
|
|
case "logServices":
|
|
// Unregister from other channels if needed
|
|
if client.Channels["labels"] {
|
|
channels["labels"].Unregister <- client
|
|
delete(client.Channels, "labels")
|
|
}
|
|
|
|
// Update client filters
|
|
client.Services = payload.Services
|
|
client.LogLevels = payload.Levels
|
|
|
|
// Register to channel
|
|
channels["logServices"].Register <- client
|
|
client.Channels["logServices"] = true
|
|
|
|
conn.WriteJSON(map[string]string{
|
|
"status": "subscribed",
|
|
"channel": "logServices",
|
|
})
|
|
|
|
case "labels":
|
|
// Unregister from other channels if needed
|
|
if client.Channels["logServices"] {
|
|
channels["logServices"].Unregister <- client
|
|
delete(client.Channels, "logServices")
|
|
}
|
|
|
|
// Set label filters if provided
|
|
if payload.Labels != nil {
|
|
client.Labels = payload.Labels
|
|
}
|
|
|
|
// Register to channel
|
|
channels["labels"].Register <- client
|
|
client.Channels["labels"] = true
|
|
|
|
// Update DB record
|
|
client.SaveToDB()
|
|
// if err := client.SaveToDB(); err != nil {
|
|
// log.Println("Failed to update client labels:", err)
|
|
// }
|
|
|
|
conn.WriteJSON(map[string]interface{}{
|
|
"status": "subscribed",
|
|
"channel": "labels",
|
|
"filters": client.Labels,
|
|
})
|
|
|
|
default:
|
|
conn.WriteJSON(map[string]string{
|
|
"error": "invalid channel",
|
|
"available_channels": "logServices, labels",
|
|
})
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Send messages to client
|
|
for message := range client.Send {
|
|
if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
|
log.Println("Write error:", err)
|
|
break
|
|
}
|
|
}
|
|
}
|