Files
logistics_support_tool/backend/cmd/services/websocket/channel_manager.go

180 lines
3.7 KiB
Go

package websocket
import (
"encoding/json"
"log"
"strings"
"sync"
logging "lst.net/utils/logger"
)
// Channel is a named group of websocket clients that receive the same
// broadcast stream. Its state is driven by RunChannel, which consumes the
// Register, Unregister and Broadcast channels.
type Channel struct {
// Name identifies the channel; RunChannel applies per-client log filtering
// only when Name is "logServices".
Name string
// Clients is the set of currently registered clients, guarded by lock.
Clients map[*Client]bool
// Register receives clients to add to Clients.
Register chan *Client
// Unregister receives clients to remove from Clients.
Unregister chan *Client
// Broadcast receives payloads to fan out; RunChannel decodes each payload
// as a JSON-encoded logging.Message before delivering it.
Broadcast chan []byte
// lock guards Clients.
lock sync.RWMutex
}
var (
// channels is the package-level registry of channels, keyed by channel name.
channels = make(map[string]*Channel)
// channelsMu guards channels.
channelsMu sync.RWMutex
)
// InitializeChannels registers the built-in channels ("logServices" and
// "labels") in the package-level registry, replacing any existing entries
// with the same names. It returns nothing and does not start the channels'
// event loops.
func InitializeChannels() {
	channelsMu.Lock()
	defer channelsMu.Unlock()

	// Add more channel names here as needed.
	for _, name := range []string{"logServices", "labels"} {
		channels[name] = NewChannel(name)
	}
}
// NewChannel builds a Channel called name with an empty client set and
// unbuffered Register, Unregister and Broadcast channels. The returned
// channel is not added to the package registry.
func NewChannel(name string) *Channel {
	ch := &Channel{Name: name}
	ch.Clients = make(map[*Client]bool)
	ch.Register = make(chan *Client)
	ch.Unregister = make(chan *Client)
	ch.Broadcast = make(chan []byte)
	return ch
}
// GetChannel looks up the registered channel called name. The boolean result
// reports whether such a channel exists in the registry.
func GetChannel(name string) (*Channel, bool) {
	channelsMu.RLock()
	found, ok := channels[name]
	channelsMu.RUnlock()

	return found, ok
}
// GetAllChannels returns a snapshot of the registry as a new map. The map
// itself is a copy, so callers may modify it freely, but the *Channel values
// are shared with the registry.
func GetAllChannels() map[string]*Channel {
	channelsMu.RLock()
	defer channelsMu.RUnlock()

	snapshot := make(map[string]*Channel, len(channels))
	for name, ch := range channels {
		snapshot[name] = ch
	}
	return snapshot
}
// StartAllChannels launches RunChannel in a new goroutine for every channel
// currently in the registry. It returns as soon as the goroutines have been
// started and does not wait for them.
func StartAllChannels() {
	channelsMu.RLock()
	for _, registered := range channels {
		go registered.RunChannel()
	}
	channelsMu.RUnlock()
}
// CleanupChannels closes the Broadcast channel of every registered channel
// and then replaces the registry with an empty map. Register and Unregister
// are left open.
//
// Note that sending on a closed Go channel panics, so producers must stop
// writing to Broadcast before this is called.
func CleanupChannels() {
	channelsMu.Lock()
	defer channelsMu.Unlock()

	for name := range channels {
		close(channels[name].Broadcast)
		// Add any other cleanup needed
	}
	channels = make(map[string]*Channel)
}
// StartBroadcasting starts a goroutine that drains broadcaster and forwards
// each message, JSON-encoded, to the Broadcast channel of the matching entry
// in channels. Only the "logServices" and "labels" channels are routed;
// messages for any other channel name are logged and dropped.
//
// The goroutine exits when broadcaster is closed. The channels argument
// shadows the package-level registry and is read without locking, so the
// caller must not mutate that map while the goroutine is running.
//
// Forwarding is a blocking send on an unbuffered channel: if the target
// channel's RunChannel loop is not running, the goroutine stalls there.
func StartBroadcasting(broadcaster chan logging.Message, channels map[string]*Channel) {
	logger := logging.New()

	// forward encodes msg and hands it to the named channel. A missing or nil
	// entry is logged and skipped instead of dereferencing a nil *Channel.
	forward := func(name string, msg logging.Message) {
		target, ok := channels[name]
		if !ok || target == nil {
			log.Printf("Received message for unregistered channel: %s", name)
			return
		}

		// Filtering is not done here; it happens in RunChannel.
		messageBytes, err := json.Marshal(msg)
		if err != nil {
			logger.Error("Error marshaling message", "websocket", map[string]interface{}{
				"errors": err,
			})
			return
		}
		target.Broadcast <- messageBytes
	}

	go func() {
		for msg := range broadcaster {
			switch msg.Channel {
			case "logServices":
				forward("logServices", msg)
			case "labels":
				forward("labels", msg)
			default:
				log.Printf("Received message for unknown channel: %s", msg.Channel)
			}
		}
	}()
}
// contains reports whether item equals any entry of slice, ignoring case.
// An empty or nil slice is treated as "no filter" and matches every item.
func contains(slice []string, item string) bool {
	// Empty filter slice means "match all".
	if len(slice) == 0 {
		return true
	}
	for _, s := range slice {
		// EqualFold compares case-insensitively without allocating lowered
		// copies of either string.
		if strings.EqualFold(s, item) {
			return true
		}
	}
	return false
}
// RunChannel is the channel's event loop. It never returns and is meant to
// run in its own goroutine. It handles three kinds of events:
//
//   - Register: adds the client to Clients.
//   - Unregister: removes the client from Clients.
//   - Broadcast: decodes the payload as a logging.Message and offers the raw
//     payload to every registered client's Send channel.
//
// For the "logServices" channel each client only receives messages whose
// Meta "level" and "service" values pass that client's LogLevels and Services
// filters; an empty filter matches everything.
//
// Delivery is non-blocking. A client whose Send channel cannot accept the
// message immediately is removed from Clients. Its Send channel is not closed
// here.
func (ch *Channel) RunChannel() {
	// Work on a local copy so the broadcast case can be disabled (by setting
	// it to nil) once the channel is closed, without writing to the shared
	// struct field.
	broadcast := ch.Broadcast

	for {
		select {
		case client := <-ch.Register:
			ch.lock.Lock()
			ch.Clients[client] = true
			ch.lock.Unlock()

		case client := <-ch.Unregister:
			ch.lock.Lock()
			delete(ch.Clients, client)
			ch.lock.Unlock()

		case message, ok := <-broadcast:
			if !ok {
				// A closed channel is always ready to receive, which would
				// make this loop spin. A nil channel is never ready, so the
				// loop keeps serving Register and Unregister only.
				broadcast = nil
				continue
			}

			var msg logging.Message
			if err := json.Unmarshal(message, &msg); err != nil {
				log.Printf("websocket channel %q: dropping undecodable broadcast: %v", ch.Name, err)
				continue
			}

			// The filter inputs depend only on the message, so extract them
			// once rather than once per client.
			filtered := ch.Name == "logServices"
			var logLevel, logService string
			if filtered {
				logLevel, _ = msg.Meta["level"].(string)
				logService, _ = msg.Meta["service"].(string)
			}

			var stalled []*Client
			ch.lock.RLock()
			for client := range ch.Clients {
				if filtered {
					levelMatch := len(client.LogLevels) == 0 || contains(client.LogLevels, logLevel)
					serviceMatch := len(client.Services) == 0 || contains(client.Services, logService)
					if !levelMatch || !serviceMatch {
						continue
					}
				}
				select {
				case client.Send <- message:
				default:
					// Do not send on ch.Unregister here: this goroutine is
					// its only receiver, so the send would block forever.
					stalled = append(stalled, client)
				}
			}
			ch.lock.RUnlock()

			// Drop clients that could not keep up, now that the read lock is
			// released and the write lock can be taken.
			if len(stalled) > 0 {
				ch.lock.Lock()
				for _, client := range stalled {
					delete(ch.Clients, client)
				}
				ch.lock.Unlock()
			}
		}
	}
}