Files
logistics_support_tool/LstWrapper/Program.cs

150 lines
5.1 KiB
C#

using System.Net;
using System.Net.WebSockets;
using System.Text;
// Create the host builder and register the named "GoBackend" HttpClient,
// whose base address is the backend service at http://localhost:8080.
var builder = WebApplication.CreateBuilder(args);
builder.Services.AddHttpClient(
    "GoBackend",
    backendClient => backendClient.BaseAddress = new Uri("http://localhost:8080"));

var app = builder.Build();

// Turn on WebSocket handling for the request pipeline.
app.UseWebSockets();
// WebSocket reverse proxy for /lst/api/logger/logs (adjust path as needed).
//
// Upgrade requests under that path are bridged to ws://localhost:8080 using the same
// path and query string; every other request is handed to the next middleware.
// If the backend connection fails before the client upgrade has been accepted, the
// caller receives 502 Bad Gateway together with the error message.
app.Use(async (context, next) =>
{
    if (!context.WebSockets.IsWebSocketRequest ||
        !context.Request.Path.StartsWithSegments("/lst/api/logger/logs"))
    {
        await next();
        return;
    }

    // Handshake headers that must not be copied verbatim: ClientWebSocket produces its
    // own Host/Upgrade/Connection/Key/Version values, this proxy does not negotiate
    // extensions (e.g. compression), and sub-protocols are requested via AddSubProtocol
    // so that the backend's selection can be validated and relayed to the client.
    string[] handshakeHeaders =
    {
        "Host",
        "Upgrade",
        "Connection",
        "Sec-WebSocket-Key",
        "Sec-WebSocket-Version",
        "Sec-WebSocket-Extensions",
        "Sec-WebSocket-Protocol"
    };

    try
    {
        var backendUri = new UriBuilder("ws", "localhost", 8080)
        {
            Path = context.Request.Path,
            Query = context.Request.QueryString.ToString()
        }.Uri;

        using var backendSocket = new ClientWebSocket();

        // Forward the remaining request headers (cookies, authorization, ...) unchanged.
        foreach (var header in context.Request.Headers)
        {
            var isHandshakeHeader = Array.Exists(
                handshakeHeaders,
                name => name.Equals(header.Key, StringComparison.OrdinalIgnoreCase));
            if (!isHandshakeHeader)
            {
                backendSocket.Options.SetRequestHeader(header.Key, header.Value);
            }
        }

        // Offer the backend the same sub-protocols the client offered to us.
        foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
        {
            backendSocket.Options.AddSubProtocol(protocol);
        }

        await backendSocket.ConnectAsync(backendUri, context.RequestAborted);

        // Accept the client with whichever sub-protocol the backend selected (null if none).
        using var frontendSocket = await context.WebSockets.AcceptWebSocketAsync(backendSocket.SubProtocol);
        using var cts = new CancellationTokenSource();

        // Bidirectional forwarding tasks
        var forwardToBackend = ForwardWebSocketAsync(frontendSocket, backendSocket, cts.Token);
        var forwardToFrontend = ForwardWebSocketAsync(backendSocket, frontendSocket, cts.Token);

        // As soon as one direction ends, stop the other one as well.
        await Task.WhenAny(forwardToBackend, forwardToFrontend);
        cts.Cancel();

        // Wait for both pumps to finish before the sockets are disposed, so that no pump
        // touches a disposed socket and no task exception is left unobserved.
        try
        {
            await Task.WhenAll(forwardToBackend, forwardToFrontend);
        }
        catch (OperationCanceledException)
        {
            // Expected: the remaining pump was cancelled just above.
        }
        catch (WebSocketException)
        {
            // Connection dropped while shutting down; nothing left to do.
        }
    }
    catch (Exception ex)
    {
        // After the upgrade was accepted the response has already started, so the status
        // code can no longer be changed; in that case just end the request.
        if (context.Response.HasStarted)
        {
            return;
        }

        context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
        await context.Response.WriteAsync($"WebSocket proxy error: {ex.Message}");
    }
});
// Proxy normal HTTP requests
//
// Terminal middleware: every non-WebSocket request is replayed against the "GoBackend"
// HttpClient (same method, path, query string, headers and body) and the backend's
// status code, headers and body are streamed back to the caller. If the backend cannot
// be reached before the response has started, the caller receives 502 Bad Gateway.
app.Use(async (context, next) =>
{
    if (context.WebSockets.IsWebSocketRequest)
    {
        await next();
        return;
    }

    var client = context.RequestServices.GetRequiredService<IHttpClientFactory>().CreateClient("GoBackend");
    try
    {
        // NOTE: the request message is intentionally not disposed here. Disposing it would
        // dispose its StreamContent and, through it, the server-owned request body stream.
        var request = new HttpRequestMessage(new HttpMethod(context.Request.Method),
            context.Request.Path + context.Request.QueryString);

        foreach (var header in context.Request.Headers)
        {
            // Headers rejected by the request header collection (Content-Type, Content-Length, ...)
            // belong on the content, which is created on demand around the request body.
            if (!request.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()))
            {
                request.Content ??= new StreamContent(context.Request.Body);
                request.Content.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
            }
        }

        if (context.Request.ContentLength > 0 && request.Content is null)
        {
            request.Content = new StreamContent(context.Request.Body);
        }

        // ResponseHeadersRead streams the body instead of buffering it; the response must
        // therefore be disposed so the backend connection is released.
        using var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, context.RequestAborted);

        context.Response.StatusCode = (int)response.StatusCode;
        foreach (var header in response.Headers)
        {
            context.Response.Headers[header.Key] = header.Value.ToArray();
        }
        foreach (var header in response.Content.Headers)
        {
            context.Response.Headers[header.Key] = header.Value.ToArray();
        }

        // The server applies its own framing to the outgoing body.
        context.Response.Headers.Remove("transfer-encoding");

        // Stop copying as soon as the caller goes away.
        await response.Content.CopyToAsync(context.Response.Body, context.RequestAborted);
    }
    catch (HttpRequestException ex) when (!context.Response.HasStarted)
    {
        // Only reachable while the status line can still be changed; failures after the
        // response has started propagate so the server can abort the connection.
        context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
        await context.Response.WriteAsync($"Backend request failed: {ex.Message}");
    }
});
// Pumps WebSocket frames from `source` to `destination` until either socket leaves the
// Open state, a Close frame is received, or `cancellationToken` is cancelled.
//
// A Close frame is relayed to `destination` with the close status and description sent by
// `source` (NormalClosure when no status is available). Connection errors, cancellation and
// use of an already-disposed socket end the pump quietly, so the returned task does not
// fault for these expected shutdown paths.
async Task ForwardWebSocketAsync(WebSocket source, WebSocket destination, CancellationToken cancellationToken)
{
    var buffer = new byte[4 * 1024];
    try
    {
        while (source.State == WebSocketState.Open &&
               destination.State == WebSocketState.Open &&
               !cancellationToken.IsCancellationRequested)
        {
            var result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);
            if (result.MessageType == WebSocketMessageType.Close)
            {
                // Relay the peer's own close reason instead of inventing one.
                await destination.CloseOutputAsync(
                    result.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
                    result.CloseStatusDescription,
                    cancellationToken);
                break;
            }

            // Forward the fragment as received; EndOfMessage preserves message boundaries
            // for payloads larger than the buffer.
            await destination.SendAsync(
                new ArraySegment<byte>(buffer, 0, result.Count),
                result.MessageType,
                result.EndOfMessage,
                cancellationToken);
        }
    }
    catch (WebSocketException)
    {
        // Normal close or network error
    }
    catch (OperationCanceledException)
    {
        // The caller cancelled the pump because the opposite direction finished.
    }
    catch (ObjectDisposedException)
    {
        // The caller disposed a socket while this pump was still winding down.
    }
}
// Start the web host and begin serving requests through the pipeline configured above.
app.Run();