diff --git a/LstWrapper/Program.cs b/LstWrapper/Program.cs index 8cc918a..760f755 100644 --- a/LstWrapper/Program.cs +++ b/LstWrapper/Program.cs @@ -1,9 +1,13 @@ +using System; +using System.IO; using System.Net; using System.Net.WebSockets; -using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; var builder = WebApplication.CreateBuilder(args); - builder.Services.AddHttpClient("GoBackend", client => { client.BaseAddress = new Uri("http://localhost:8080"); @@ -14,13 +18,30 @@ var app = builder.Build(); // Enable WebSocket support app.UseWebSockets(); +// Logging method +void LogToFile(string message) +{ + try + { + string logDir = Path.Combine(AppContext.BaseDirectory, "logs"); + Directory.CreateDirectory(logDir); + string logFilePath = Path.Combine(logDir, "proxy_log.txt"); + File.AppendAllText(logFilePath, $"{DateTime.UtcNow}: {message}{Environment.NewLine}"); + } + catch (Exception ex) + { + // Handle potential errors writing to log file + Console.WriteLine($"Logging error: {ex.Message}"); + } +} + +// Middleware to handle WebSocket requests app.Use(async (context, next) => { - // Proxy WebSocket requests for /lst/api/logger/logs (adjust path as needed) - if (context.WebSockets.IsWebSocketRequest && - context.Request.Path.StartsWithSegments("/lst/api/logger/logs")) + if (context.WebSockets.IsWebSocketRequest && context.Request.Path.StartsWithSegments("/ws")) { - Console.WriteLine("WebSocket request received!"); + LogToFile($"WebSocket request received for path: {context.Request.Path}"); + try { var backendUri = new UriBuilder("ws", "localhost", 8080) @@ -30,46 +51,32 @@ app.Use(async (context, next) => }.Uri; using var backendSocket = new ClientWebSocket(); - - // Forward most headers except those managed by WebSocket protocol - foreach (var header in context.Request.Headers) - { - if (!header.Key.Equals("Host", StringComparison.OrdinalIgnoreCase) && - !header.Key.Equals("Upgrade", StringComparison.OrdinalIgnoreCase) && - !header.Key.Equals("Connection", StringComparison.OrdinalIgnoreCase) && - !header.Key.Equals("Sec-WebSocket-Key", StringComparison.OrdinalIgnoreCase) && - !header.Key.Equals("Sec-WebSocket-Version", StringComparison.OrdinalIgnoreCase)) - { - backendSocket.Options.SetRequestHeader(header.Key, header.Value); - } - } - await backendSocket.ConnectAsync(backendUri, context.RequestAborted); using var frontendSocket = await context.WebSockets.AcceptWebSocketAsync(); - var cts = new CancellationTokenSource(); - // Bidirectional forwarding tasks + // WebSocket forwarding tasks var forwardToBackend = ForwardWebSocketAsync(frontendSocket, backendSocket, cts.Token); var forwardToFrontend = ForwardWebSocketAsync(backendSocket, frontendSocket, cts.Token); await Task.WhenAny(forwardToBackend, forwardToFrontend); cts.Cancel(); - return; } catch (Exception ex) { + LogToFile($"WebSocket proxy error: {ex.Message}"); context.Response.StatusCode = (int)HttpStatusCode.BadGateway; await context.Response.WriteAsync($"WebSocket proxy error: {ex.Message}"); - return; } } - - await next(); + else + { + await next(); + } }); -// Proxy normal HTTP requests +// Middleware to handle HTTP requests app.Use(async (context, next) => { if (context.WebSockets.IsWebSocketRequest) @@ -100,24 +107,24 @@ app.Use(async (context, next) => } var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted); - context.Response.StatusCode = (int)response.StatusCode; foreach (var header in response.Headers) { context.Response.Headers[header.Key] = header.Value.ToArray(); } + foreach (var header in response.Content.Headers) { context.Response.Headers[header.Key] = header.Value.ToArray(); } context.Response.Headers.Remove("transfer-encoding"); - await response.Content.CopyToAsync(context.Response.Body); } catch (HttpRequestException ex) { + LogToFile($"HTTP proxy error: {ex.Message}"); context.Response.StatusCode = (int)HttpStatusCode.BadGateway; await context.Response.WriteAsync($"Backend request failed: {ex.Message}"); } @@ -141,10 +148,11 @@ async Task ForwardWebSocketAsync(WebSocket source, WebSocket destination, Cancel await destination.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); } } - catch (WebSocketException) + catch (WebSocketException ex) { - // Normal close or network error + LogToFile($"WebSocket forwarding error: {ex.Message}"); + await destination.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, "Error", cancellationToken); } } -app.Run(); +app.Run(); \ No newline at end of file