fix(wrapper): corrections to properly handle websockets :D

This commit is contained in:
2025-07-29 15:58:21 -05:00
parent a1a30cffd1
commit a761a3634b

View File

@@ -1,9 +1,13 @@
using System;
using System.IO;
using System.Net; using System.Net;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text; using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
var builder = WebApplication.CreateBuilder(args); var builder = WebApplication.CreateBuilder(args);
builder.Services.AddHttpClient("GoBackend", client => builder.Services.AddHttpClient("GoBackend", client =>
{ {
client.BaseAddress = new Uri("http://localhost:8080"); client.BaseAddress = new Uri("http://localhost:8080");
@@ -14,13 +18,30 @@ var app = builder.Build();
// Enable WebSocket support // Enable WebSocket support
app.UseWebSockets(); app.UseWebSockets();
// Logging method
void LogToFile(string message)
{
try
{
string logDir = Path.Combine(AppContext.BaseDirectory, "logs");
Directory.CreateDirectory(logDir);
string logFilePath = Path.Combine(logDir, "proxy_log.txt");
File.AppendAllText(logFilePath, $"{DateTime.UtcNow}: {message}{Environment.NewLine}");
}
catch (Exception ex)
{
// Handle potential errors writing to log file
Console.WriteLine($"Logging error: {ex.Message}");
}
}
// Middleware to handle WebSocket requests
app.Use(async (context, next) => app.Use(async (context, next) =>
{ {
// Proxy WebSocket requests for /lst/api/logger/logs (adjust path as needed) if (context.WebSockets.IsWebSocketRequest && context.Request.Path.StartsWithSegments("/ws"))
if (context.WebSockets.IsWebSocketRequest &&
context.Request.Path.StartsWithSegments("/lst/api/logger/logs"))
{ {
Console.WriteLine("WebSocket request received!"); LogToFile($"WebSocket request received for path: {context.Request.Path}");
try try
{ {
var backendUri = new UriBuilder("ws", "localhost", 8080) var backendUri = new UriBuilder("ws", "localhost", 8080)
@@ -30,46 +51,32 @@ app.Use(async (context, next) =>
}.Uri; }.Uri;
using var backendSocket = new ClientWebSocket(); using var backendSocket = new ClientWebSocket();
// Forward most headers except those managed by WebSocket protocol
foreach (var header in context.Request.Headers)
{
if (!header.Key.Equals("Host", StringComparison.OrdinalIgnoreCase) &&
!header.Key.Equals("Upgrade", StringComparison.OrdinalIgnoreCase) &&
!header.Key.Equals("Connection", StringComparison.OrdinalIgnoreCase) &&
!header.Key.Equals("Sec-WebSocket-Key", StringComparison.OrdinalIgnoreCase) &&
!header.Key.Equals("Sec-WebSocket-Version", StringComparison.OrdinalIgnoreCase))
{
backendSocket.Options.SetRequestHeader(header.Key, header.Value);
}
}
await backendSocket.ConnectAsync(backendUri, context.RequestAborted); await backendSocket.ConnectAsync(backendUri, context.RequestAborted);
using var frontendSocket = await context.WebSockets.AcceptWebSocketAsync(); using var frontendSocket = await context.WebSockets.AcceptWebSocketAsync();
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
// Bidirectional forwarding tasks // WebSocket forwarding tasks
var forwardToBackend = ForwardWebSocketAsync(frontendSocket, backendSocket, cts.Token); var forwardToBackend = ForwardWebSocketAsync(frontendSocket, backendSocket, cts.Token);
var forwardToFrontend = ForwardWebSocketAsync(backendSocket, frontendSocket, cts.Token); var forwardToFrontend = ForwardWebSocketAsync(backendSocket, frontendSocket, cts.Token);
await Task.WhenAny(forwardToBackend, forwardToFrontend); await Task.WhenAny(forwardToBackend, forwardToFrontend);
cts.Cancel(); cts.Cancel();
return;
} }
catch (Exception ex) catch (Exception ex)
{ {
LogToFile($"WebSocket proxy error: {ex.Message}");
context.Response.StatusCode = (int)HttpStatusCode.BadGateway; context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
await context.Response.WriteAsync($"WebSocket proxy error: {ex.Message}"); await context.Response.WriteAsync($"WebSocket proxy error: {ex.Message}");
return;
} }
} }
else
await next(); {
await next();
}
}); });
// Proxy normal HTTP requests // Middleware to handle HTTP requests
app.Use(async (context, next) => app.Use(async (context, next) =>
{ {
if (context.WebSockets.IsWebSocketRequest) if (context.WebSockets.IsWebSocketRequest)
@@ -100,24 +107,24 @@ app.Use(async (context, next) =>
} }
var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted); var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted);
context.Response.StatusCode = (int)response.StatusCode; context.Response.StatusCode = (int)response.StatusCode;
foreach (var header in response.Headers) foreach (var header in response.Headers)
{ {
context.Response.Headers[header.Key] = header.Value.ToArray(); context.Response.Headers[header.Key] = header.Value.ToArray();
} }
foreach (var header in response.Content.Headers) foreach (var header in response.Content.Headers)
{ {
context.Response.Headers[header.Key] = header.Value.ToArray(); context.Response.Headers[header.Key] = header.Value.ToArray();
} }
context.Response.Headers.Remove("transfer-encoding"); context.Response.Headers.Remove("transfer-encoding");
await response.Content.CopyToAsync(context.Response.Body); await response.Content.CopyToAsync(context.Response.Body);
} }
catch (HttpRequestException ex) catch (HttpRequestException ex)
{ {
LogToFile($"HTTP proxy error: {ex.Message}");
context.Response.StatusCode = (int)HttpStatusCode.BadGateway; context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
await context.Response.WriteAsync($"Backend request failed: {ex.Message}"); await context.Response.WriteAsync($"Backend request failed: {ex.Message}");
} }
@@ -141,9 +148,10 @@ async Task ForwardWebSocketAsync(WebSocket source, WebSocket destination, Cancel
await destination.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); await destination.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken);
} }
} }
catch (WebSocketException) catch (WebSocketException ex)
{ {
// Normal close or network error LogToFile($"WebSocket forwarding error: {ex.Message}");
await destination.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, "Error", cancellationToken);
} }
} }