forked from nadoo/glider
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrule.go
133 lines (107 loc) · 2.69 KB
/
rule.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main
import (
"log"
"net"
"strings"
"sync"
)
// RuleDialer picks a dialer per destination from the configured rules and
// falls back to a global dialer when no rule applies.
type RuleDialer struct {
	// gDialer is the fallback handed out when no rule matches.
	gDialer Dialer

	// Rule tables. Each value is the dialer built for the rule that listed
	// the key: domainMap is keyed by domain name, ipMap by IP string and
	// cidrMap by *net.IPNet.
	domainMap, ipMap, cidrMap sync.Map
}
// NewRuleDialer builds a RuleDialer from rules, using gDialer for any
// destination that matches no rule.
//
// For each rule, every entry of r.Forward is a comma-separated chain of
// forwarder URLs. The URLs are passed to DialerFromURL left to right, each
// call receiving the dialer built by the previous one. The resulting
// dialers are combined with NewStrategyDialer, and that single dialer is
// registered for every domain, IP and CIDR listed in the rule.
//
// IP entries are stored in canonical net.IP.String() form, because that is
// the form NextDialer looks up. A spelling such as "2001:DB8::1" would
// otherwise never match. IP and CIDR entries that fail to parse are logged
// and skipped. A URL rejected by DialerFromURL is fatal (log.Fatal).
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) *RuleDialer {
	rd := &RuleDialer{gDialer: gDialer}

	for _, r := range rules {
		var forwarders []Dialer
		for _, chain := range r.Forward {
			var forward Dialer
			var err error
			for _, u := range strings.Split(chain, ",") {
				forward, err = DialerFromURL(u, forward)
				if err != nil {
					log.Fatal(err)
				}
			}
			forwarders = append(forwarders, forward)
		}

		sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration)

		for _, domain := range r.Domain {
			rd.domainMap.Store(domain, sd)
		}

		for _, s := range r.IP {
			ip := net.ParseIP(s)
			if ip == nil {
				// An unparseable entry could never be matched by a lookup
				// keyed on net.IP.String(), so report it rather than store it.
				log.Printf("rule: invalid ip %q, skipped", s)
				continue
			}
			rd.ipMap.Store(ip.String(), sd)
		}

		for _, s := range r.CIDR {
			_, cidr, err := net.ParseCIDR(s)
			if err != nil {
				log.Printf("rule: invalid cidr %q, skipped: %v", s, err)
				continue
			}
			rd.cidrMap.Store(cidr, sd)
		}
	}

	return rd
}
// Addr reports the RuleDialer's address. It is not a network address: the
// same constant "RULES" is returned for every instance.
func (rd *RuleDialer) Addr() string {
	const addr = "RULES"
	return addr
}
// NextDialer returns the dialer that should handle dstAddr ("host:port").
//
// Matching order:
//  1. If host is an IP literal, an exact entry in ipMap keyed by the
//     canonical net.IP.String() form.
//  2. Still for IP literals, the most specific (longest-prefix) CIDR in
//     cidrMap that contains the IP. sync.Map.Range gives no ordering
//     guarantee, so taking the first hit would make overlapping CIDR rules
//     resolve nondeterministically. Entries for the identical network
//     remain unordered; this can happen when several rules list the same
//     CIDR, because each rule stores its own *net.IPNet key.
//  3. Domain suffixes of host, from the shortest (two labels) up to the
//     full host. The first entry found in domainMap wins. The last label
//     on its own, e.g. "com", is never looked up.
//
// If dstAddr cannot be split into host and port, or nothing matches, the
// global dialer gDialer is returned.
func (rd *RuleDialer) NextDialer(dstAddr string) Dialer {
	// TODO: change to index finders
	host, _, err := net.SplitHostPort(dstAddr)
	if err != nil {
		// TODO: check here
		// logf("proxy-rule SplitHostPort ERROR: %s", err)
		return rd.gDialer
	}

	if ip := net.ParseIP(host); ip != nil {
		if d, ok := rd.ipMap.Load(ip.String()); ok {
			return d.(Dialer)
		}

		var best Dialer
		bestOnes := -1 // stays negative until some CIDR contains ip
		rd.cidrMap.Range(func(key, value interface{}) bool {
			cidr := key.(*net.IPNet)
			if ones, _ := cidr.Mask.Size(); ones > bestOnes && cidr.Contains(ip) {
				best, bestOnes = value.(Dialer), ones
			}
			return true // keep scanning: a longer prefix may still follow
		})
		if bestOnes >= 0 {
			return best
		}
	}

	labels := strings.Split(host, ".")
	n := len(labels)
	for i := n - 2; i >= 0; i-- {
		if d, ok := rd.domainMap.Load(strings.Join(labels[i:], ".")); ok {
			return d.(Dialer)
		}
	}

	return rd.gDialer
}
// Dial connects to addr over network using whichever dialer NextDialer
// selects for addr, and returns that dialer's connection and error as-is.
func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) {
	chosen := rd.NextDialer(addr)
	return chosen.Dial(network, addr)
}
// AddDomainIP records that ip was resolved for domain. Later connections to
// that IP then use the dialer of the domain's rule.
//
// Suffixes of domain are tried in the same order NextDialer tries them:
// shortest (two labels) first, never the last label alone. Only the first
// match is used, so a host is routed by the same rule whether it is reached
// by name or by address. The IP is stored under its canonical
// net.IP.String() form, which is the key NextDialer looks up.
//
// An empty ip is ignored and nil is returned. A non-empty ip that does not
// parse as an IP address yields a *net.ParseError, and nothing is stored.
func (rd *RuleDialer) AddDomainIP(domain, ip string) error {
	if ip == "" {
		return nil
	}

	parsed := net.ParseIP(ip)
	if parsed == nil {
		return &net.ParseError{Type: "IP address", Text: ip}
	}
	key := parsed.String()

	labels := strings.Split(domain, ".")
	n := len(labels)
	for i := n - 2; i >= 0; i-- {
		suffix := strings.Join(labels[i:], ".")
		// find in domainMap
		if d, ok := rd.domainMap.Load(suffix); ok {
			rd.ipMap.Store(key, d)
			logf("rule: add domain: %s, ip: %s\n", suffix, key)
			return nil
		}
	}
	return nil
}