Skip to content

Commit

Permalink
Feature: add PROCESS-NAME rule for linux (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kr328 authored Jul 22, 2020
1 parent 20eff20 commit 4f73410
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 1 deletion.
1 change: 1 addition & 0 deletions rules/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ var (
errPayload = errors.New("payload error")
errParams = errors.New("params error")
ErrPlatformNotSupport = errors.New("not support on this platform")
ErrInvalidNetwork = errors.New("invalid network")

noResolve = "no-resolve"
)
Expand Down
291 changes: 291 additions & 0 deletions rules/process_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
package rules

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net"
"path"
"path/filepath"
"strconv"
"strings"
"syscall"
"unsafe"

"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)

// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62
func init() {
var x uint32 = 0x01020304
if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
nativeEndian = binary.BigEndian
} else {
nativeEndian = binary.LittleEndian
}
}

type SocketResolver func(metadata *C.Metadata) (inode, uid int, err error)
type ProcessNameResolver func(inode, uid int) (name string, err error)

// export for android
var (
DefaultSocketResolver SocketResolver = resolveSocketByNetlink
DefaultProcessNameResolver ProcessNameResolver = resolveProcessNameByProcSeach
)

type Process struct {
adapter string
process string
}

func (p *Process) RuleType() C.RuleType {
return C.Process
}

func (p *Process) Match(metadata *C.Metadata) bool {
key := fmt.Sprintf("%s:%s:%s", metadata.NetWork.String(), metadata.SrcIP.String(), metadata.SrcPort)
cached, hit := processCache.Get(key)
if !hit {
processName, err := resolveProcessName(metadata)
if err != nil {
log.Debugln("[%s] Resolve process of %s failure: %s", C.Process.String(), key, err.Error())
}

processCache.Set(key, processName)

cached = processName
}

return strings.EqualFold(cached.(string), p.process)
}

func (p *Process) Adapter() string {
return p.adapter
}

func (p *Process) Payload() string {
return p.process
}

func (p *Process) NoResolveIP() bool {
return true
}

func NewProcess(process string, adapter string) (*Process, error) {
return &Process{
adapter: adapter,
process: process,
}, nil
}

const (
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48
socketDiagByFamily = 20
pathProc = "/proc"
)

var nativeEndian binary.ByteOrder = binary.LittleEndian

var processCache = cache.NewLRUCache(cache.WithAge(2), cache.WithSize(64))

func resolveProcessName(metadata *C.Metadata) (string, error) {
inode, uid, err := DefaultSocketResolver(metadata)
if err != nil {
return "", err
}

return DefaultProcessNameResolver(inode, uid)
}

func resolveSocketByNetlink(metadata *C.Metadata) (int, int, error) {
var family byte
var protocol byte

switch metadata.NetWork {
case C.TCP:
protocol = syscall.IPPROTO_TCP
case C.UDP:
protocol = syscall.IPPROTO_UDP
default:
return 0, 0, ErrInvalidNetwork
}

if metadata.SrcIP.To4() != nil {
family = syscall.AF_INET
} else {
family = syscall.AF_INET6
}

srcPort, err := strconv.Atoi(metadata.SrcPort)
if err != nil {
return 0, 0, err
}

req := packSocketDiagRequest(family, protocol, metadata.SrcIP, uint16(srcPort))

socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
if err != nil {
return 0, 0, err
}
defer syscall.Close(socket)

syscall.SetNonblock(socket, true)
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, &syscall.Timeval{Usec: 50})
syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &syscall.Timeval{Usec: 50})

if err := syscall.Connect(socket, &syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK,
Pad: 0,
Pid: 0,
Groups: 0,
}); err != nil {
return 0, 0, err
}

if _, err := syscall.Write(socket, req); err != nil {
return 0, 0, err
}

rb := pool.Get(pool.RelayBufferSize)
defer pool.Put(rb)

n, err := syscall.Read(socket, rb)
if err != nil {
return 0, 0, err
}

messages, err := syscall.ParseNetlinkMessage(rb[:n])
if err != nil {
return 0, 0, err
} else if len(messages) == 0 {
return 0, 0, io.ErrUnexpectedEOF
}

message := messages[0]
if message.Header.Type&syscall.NLMSG_ERROR != 0 {
return 0, 0, syscall.ESRCH
}

uid, inode := unpackSocketDiagResponse(&messages[0])

return int(uid), int(inode), nil
}

func packSocketDiagRequest(family, protocol byte, source net.IP, sourcePort uint16) []byte {
s := make([]byte, 16)

if v4 := source.To4(); v4 != nil {
copy(s, v4)
} else {
copy(s, source)
}

buf := make([]byte, sizeOfSocketDiagRequest)

nativeEndian.PutUint32(buf[0:4], sizeOfSocketDiagRequest)
nativeEndian.PutUint16(buf[4:6], socketDiagByFamily)
nativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP)
nativeEndian.PutUint32(buf[8:12], 0)
nativeEndian.PutUint32(buf[12:16], 0)

buf[16] = family
buf[17] = protocol
buf[18] = 0
buf[19] = 0
nativeEndian.PutUint32(buf[20:24], 0xFFFFFFFF)

binary.BigEndian.PutUint16(buf[24:26], sourcePort)
binary.BigEndian.PutUint16(buf[26:28], 0)

copy(buf[28:44], s)
copy(buf[44:60], net.IPv6zero)

nativeEndian.PutUint32(buf[60:64], 0)
nativeEndian.PutUint64(buf[64:72], 0xFFFFFFFFFFFFFFFF)

return buf
}

func unpackSocketDiagResponse(msg *syscall.NetlinkMessage) (inode, uid uint32) {
if len(msg.Data) < 72 {
return 0, 0
}

data := msg.Data

uid = nativeEndian.Uint32(data[64:68])
inode = nativeEndian.Uint32(data[68:72])

return
}

func resolveProcessNameByProcSeach(inode, _ int) (string, error) {
files, err := ioutil.ReadDir(pathProc)
if err != nil {
return "", err
}

buffer := make([]byte, syscall.PathMax)
socket := []byte(fmt.Sprintf("socket:[%d]", inode))

for _, f := range files {
if !isPid(f.Name()) {
continue
}

processPath := path.Join(pathProc, f.Name())
fdPath := path.Join(processPath, "fd")

fds, err := ioutil.ReadDir(fdPath)
if err != nil {
continue
}

for _, fd := range fds {
n, err := syscall.Readlink(path.Join(fdPath, fd.Name()), buffer)
if err != nil {
continue
}

if bytes.Compare(buffer[:n], socket) == 0 {
cmdline, err := ioutil.ReadFile(path.Join(processPath, "cmdline"))
if err != nil {
return "", err
}

return splitCmdline(cmdline), nil
}
}
}

return "", syscall.ESRCH
}

func splitCmdline(cmdline []byte) string {
indexOfEndOfString := len(cmdline)

for i, c := range cmdline {
if c == 0 {
indexOfEndOfString = i
break
}
}

return filepath.Base(string(cmdline[:indexOfEndOfString]))
}

func isPid(s string) bool {
for _, s := range s {
if s < '0' || s > '9' {
return false
}
}

return true
}
2 changes: 1 addition & 1 deletion rules/process_other.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build !darwin
// +build !darwin,!linux

package rules

Expand Down

0 comments on commit 4f73410

Please sign in to comment.