Skip to content

Commit

Permalink
Feature(dns): support custom hosts
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamacro committed Jul 14, 2019
1 parent f867f02 commit 1a21c8e
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 78 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ experimental:
# listen: 0.0.0.0:53
# enhanced-mode: redir-host # or fake-ip
# # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it
# # experimental hosts, support wildcard (e.g. *.clash.dev Even *.foo.*.example.com)
# # static domain has a higher priority than wildcard domain (foo.example.com > *.example.com)
# # NOTE: hosts don't work with `fake-ip`
# hosts:
# '*.clash.dev': 127.0.0.1
# 'alpha.clash.dev': '::1'
# nameserver:
# - 114.114.114.114
# - tls://dns.rubyfish.cn:853 # dns over tls
Expand Down
26 changes: 26 additions & 0 deletions component/domain-trie/node.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package trie

// Node is the trie's node
type Node struct {
Data interface{}
children map[string]*Node
}

func (n *Node) getChild(s string) *Node {
return n.children[s]
}

func (n *Node) hasChild(s string) bool {
return n.getChild(s) != nil
}

func (n *Node) addChild(s string, child *Node) {
n.children[s] = child
}

func newNode(data interface{}) *Node {
return &Node{
Data: data,
children: map[string]*Node{},
}
}
84 changes: 84 additions & 0 deletions component/domain-trie/tire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package trie

import (
"errors"
"strings"
)

const (
wildcard = "*"
domainStep = "."
)

var (
// ErrInvalidDomain means insert domain is invalid
ErrInvalidDomain = errors.New("invalid domain")
)

// Trie contains the main logic for adding and searching nodes for domain segments.
// support wildcard domain (e.g *.google.com)
type Trie struct {
root *Node
}

// Insert adds a node to the trie.
// Support
// 1. www.example.com
// 2. *.example.com
// 3. subdomain.*.example.com
func (t *Trie) Insert(domain string, data interface{}) error {
parts := strings.Split(domain, domainStep)
if len(parts) < 2 {
return ErrInvalidDomain
}

node := t.root
// reverse storage domain part to save space
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
if !node.hasChild(part) {
node.addChild(part, newNode(nil))
}

node = node.getChild(part)
}

node.Data = data
return nil
}

// Search is the most important part of the Trie.
// Priority as:
// 1. static part
// 2. wildcard domain
func (t *Trie) Search(domain string) *Node {
parts := strings.Split(domain, domainStep)
if len(parts) < 2 {
return nil
}

n := t.root
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]

var child *Node
if !n.hasChild(part) {
if !n.hasChild(wildcard) {
return nil
}

child = n.getChild(wildcard)
} else {
child = n.getChild(part)
}

n = child
}

return n
}

// New returns a new, empty Trie.
func New() *Trie {
return &Trie{root: newNode(nil)}
}
69 changes: 69 additions & 0 deletions component/domain-trie/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package trie

import (
"net"
"testing"
)

func TestTrie_Basic(t *testing.T) {
tree := New()
domains := []string{
"example.com",
"google.com",
}

for _, domain := range domains {
tree.Insert(domain, net.ParseIP("127.0.0.1"))
}

node := tree.Search("example.com")
if node == nil {
t.Error("should not recv nil")
}

if !node.Data.(net.IP).Equal(net.IP{127, 0, 0, 1}) {
t.Error("should equal 127.0.0.1")
}
}

func TestTrie_Wildcard(t *testing.T) {
tree := New()
domains := []string{
"*.example.com",
"sub.*.example.com",
"*.dev",
}

for _, domain := range domains {
tree.Insert(domain, nil)
}

if tree.Search("sub.example.com") == nil {
t.Error("should not recv nil")
}

if tree.Search("sub.foo.example.com") == nil {
t.Error("should not recv nil")
}

if tree.Search("foo.sub.example.com") != nil {
t.Error("should recv nil")
}

if tree.Search("foo.example.dev") != nil {
t.Error("should recv nil")
}
}

func TestTrie_Boundary(t *testing.T) {
tree := New()
tree.Insert("*.dev", nil)

if err := tree.Insert("com", nil); err == nil {
t.Error("should recv err")
}

if tree.Search("dev") != nil {
t.Error("should recv nil")
}
}
30 changes: 23 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
adapters "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/common/structure"
"github.com/Dreamacro/clash/component/auth"
trie "github.com/Dreamacro/clash/component/domain-trie"
"github.com/Dreamacro/clash/component/fakeip"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
Expand Down Expand Up @@ -42,6 +43,7 @@ type DNS struct {
IPv6 bool `yaml:"ipv6"`
NameServer []dns.NameServer `yaml:"nameserver"`
Fallback []dns.NameServer `yaml:"fallback"`
Hosts *trie.Trie `yaml:"-"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange *fakeip.Pool
Expand All @@ -63,13 +65,14 @@ type Config struct {
}

type rawDNS struct {
Enable bool `yaml:"enable"`
IPv6 bool `yaml:"ipv6"`
NameServer []string `yaml:"nameserver"`
Fallback []string `yaml:"fallback"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"`
Enable bool `yaml:"enable"`
IPv6 bool `yaml:"ipv6"`
NameServer []string `yaml:"nameserver"`
Hosts map[string]string `yaml:"hosts"`
Fallback []string `yaml:"fallback"`
Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"`
}

type rawConfig struct {
Expand Down Expand Up @@ -134,6 +137,7 @@ func readConfig(path string) (*rawConfig, error) {
DNS: rawDNS{
Enable: false,
FakeIPRange: "198.18.0.1/16",
Hosts: map[string]string{},
},
}
err = yaml.Unmarshal([]byte(data), &rawConfig)
Expand Down Expand Up @@ -518,6 +522,18 @@ func parseDNS(cfg rawDNS) (*DNS, error) {
return nil, err
}

if len(cfg.Hosts) != 0 {
tree := trie.New()
for domain, ipStr := range cfg.Hosts {
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("%s is not a valid IP", ipStr)
}
tree.Insert(domain, ip)
}
dnsCfg.Hosts = tree
}

if cfg.EnhancedMode == dns.FAKEIP {
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange)
if err != nil {
Expand Down
121 changes: 121 additions & 0 deletions dns/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package dns

import (
"fmt"
"net"
"strings"

"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/log"

D "github.com/miekg/dns"
)

type handler func(w D.ResponseWriter, r *D.Msg)

func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler {
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]

cacheItem := cache.Get("fakeip:" + q.String())
if cache != nil {
msg := cacheItem.(*D.Msg).Copy()
setMsgTTL(msg, 1)
msg.SetReply(r)
w.WriteMsg(msg)
return
}

rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := pool.Get()
rr.A = ip
msg := r.Copy()
msg.Answer = []D.RR{rr}
putMsgToCache(cache, "fakeip:"+q.String(), msg)
putMsgToCache(cache, ip.String(), msg)

setMsgTTL(msg, 1)
return
}
}

func withResolver(resolver *Resolver) handler {
return func(w D.ResponseWriter, r *D.Msg) {
msg, err := resolver.Exchange(r)

if err != nil {
q := r.Question[0]
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
D.HandleFailed(w, r)
return
}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}

func withHost(resolver *Resolver, next handler) handler {
hosts := resolver.hosts
if hosts == nil {
panic("dns/withHost: hosts should not be nil")
}

return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
next(w, r)
return
}

domain := strings.TrimRight(q.Name, ".")
host := hosts.Search(domain)
if host == nil {
next(w, r)
return
}

ip := host.Data.(net.IP)
if q.Qtype == D.TypeAAAA && ip.To16() == nil {
next(w, r)
return
} else if q.Qtype == D.TypeA && ip.To4() == nil {
next(w, r)
return
}

var rr D.RR
if q.Qtype == D.TypeAAAA {
record := &D.AAAA{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.AAAA = ip
rr = record
} else {
record := &D.A{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.A = ip
rr = record
}

msg := r.Copy()
msg.Answer = []D.RR{rr}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}

func newHandler(resolver *Resolver) handler {
if resolver.IsFakeIP() {
return withFakeIP(resolver.cache, resolver.pool)
}

if resolver.hosts != nil {
return withHost(resolver, withResolver(resolver))
}

return withResolver(resolver)
}
Loading

0 comments on commit 1a21c8e

Please sign in to comment.