forked from shiftwinting/FastGithub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TcpReverseProxyHandler.cs
77 lines (73 loc) · 3.01 KB
/
TcpReverseProxyHandler.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.HttpServer
{
/// <summary>
/// tcp反射代理处理者
/// </summary>
/// <summary>
/// TCP reverse proxy handler: for every accepted connection it opens a TCP
/// connection to a fixed upstream endpoint and relays bytes in both directions.
/// </summary>
abstract class TcpReverseProxyHandler : ConnectionHandler
{
    private readonly IDomainResolver domainResolver;
    private readonly DnsEndPoint endPoint;

    // Upper bound for a single connect attempt to one resolved address.
    private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

    /// <summary>
    /// Creates a TCP reverse proxy handler.
    /// </summary>
    /// <param name="domainResolver">Resolver used to obtain the addresses of <paramref name="endPoint"/>.</param>
    /// <param name="endPoint">Upstream host and port that incoming connections are forwarded to.</param>
    protected TcpReverseProxyHandler(IDomainResolver domainResolver, DnsEndPoint endPoint)
    {
        this.domainResolver = domainResolver;
        this.endPoint = endPoint;
    }

    /// <summary>
    /// Called after a client TCP connection is accepted.
    /// Connects to the upstream endpoint and pumps data between the client
    /// transport and the upstream stream.
    /// </summary>
    /// <param name="context">The accepted client connection.</param>
    /// <returns>A task that completes when either relay direction finishes.</returns>
    /// <exception cref="AggregateException">No resolved upstream address could be connected.</exception>
    public override async Task OnConnectedAsync(ConnectionContext context)
    {
        var cancellationToken = context.ConnectionClosed;

        // The returned stream owns the upstream socket, so disposing it here closes the socket.
        using var connection = await this.CreateConnectionAsync(cancellationToken);

        var upstreamToClient = connection.CopyToAsync(context.Transport.Output, cancellationToken);
        var clientToUpstream = context.Transport.Input.CopyToAsync(connection, cancellationToken);

        // Stop as soon as either direction completes. The other copy is not awaited.
        // NOTE(review): it is expected to end once `connection` is disposed and the client
        // transport is torn down after this method returns — confirm no cleanup depends on it.
        await Task.WhenAny(upstreamToClient, clientToUpstream);
    }

    /// <summary>
    /// Opens a TCP connection to the upstream endpoint, trying each resolved
    /// address in the order the resolver yields them until one succeeds.
    /// </summary>
    /// <param name="cancellationToken">Cancels the whole connect operation.</param>
    /// <returns>A stream over the connected socket; disposing the stream also disposes the socket.</returns>
    /// <exception cref="AggregateException">Every address failed (or none was resolved); the individual failures are the inner exceptions.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
    {
        var innerExceptions = new List<Exception>();
        await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint, cancellationToken))
        {
            var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
                using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                await socket.ConnectAsync(address, this.endPoint.Port, linkedTokenSource.Token);

                // ownsSocket: true — the caller only disposes the stream, so the stream must
                // close the socket; with ownsSocket: false the socket would leak.
                return new NetworkStream(socket, ownsSocket: true);
            }
            catch (Exception ex)
            {
                socket.Dispose();

                // Caller cancellation aborts everything; a per-attempt timeout or connect
                // error is recorded and the next address is tried.
                cancellationToken.ThrowIfCancellationRequested();
                innerExceptions.Add(ex);
            }
        }
        throw new AggregateException($"无法连接到{this.endPoint.Host}:{this.endPoint.Port}", innerExceptions);
    }
}
}