forked from open-policy-agent/opa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtokens.go
199 lines (166 loc) · 5.65 KB
/
tokens.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"crypto"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Header parameter names inspected while decoding a JWT.
var jwtEncKey = ast.StringTerm("enc") // its presence marks a JWE (RFC 7516 Section 9)
var jwtCtyKey = ast.StringTerm("cty") // content type; "JWT" signals a nested token
// builtinJWTDecode implements JWT decoding based on RFC 7519 Section 7.2:
// https://tools.ietf.org/html/rfc7519#section-7.2
//
// It does no data validation; it merely checks that the given string
// represents a structurally valid JWT. Only JWTs using JWS compact
// serialization are supported (tokens whose header marks them as JWE are
// rejected).
//
// On success it returns a 3-element array: the header object, the payload
// object, and the signature bytes as a hex-encoded string. If the header's
// "cty" (content type) is the string "JWT", the payload is treated as a
// nested token and the result of decoding that nested token is returned.
func builtinJWTDecode(a ast.Value) (ast.Value, error) {
	astEncode, err := builtins.StringOperand(a, 1)
	if err != nil {
		return nil, err
	}
	encoding := string(astEncode)
	if !strings.Contains(encoding, ".") {
		return nil, errors.New("encoded JWT had no period separators")
	}
	parts := strings.Split(encoding, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("encoded JWT must have 3 sections, found %d", len(parts))
	}

	h, err := builtinBase64UrlDecode(ast.String(parts[0]))
	if err != nil {
		return nil, fmt.Errorf("JWT header had invalid encoding: %v", err)
	}
	header, err := validateJWTHeader(string(h.(ast.String)))
	if err != nil {
		return nil, err
	}

	p, err := builtinBase64UrlDecode(ast.String(parts[1]))
	if err != nil {
		return nil, fmt.Errorf("JWT payload had invalid encoding: %v", err)
	}

	// It is possible for the contents of a token to be another token as a
	// result of nested signing or encryption. To handle that case, check the
	// content type and recurse on the payload if it is "JWT".
	//
	// The header comes from untrusted input, so "cty" may not be a string;
	// a non-string content type cannot be "JWT" and is simply ignored
	// rather than triggering a panicking type assertion.
	if cty := header.Get(jwtCtyKey); cty != nil {
		if ctyVal, ok := cty.Value.(ast.String); ok && string(ctyVal) == "JWT" {
			// When the payload is itself an encoded JWT its contents are
			// quoted (behavior of https://jwt.io/), so strip leading and
			// trailing quotes before recursing.
			p, err = builtinTrim(p, ast.String(`"'`))
			if err != nil {
				return nil, fmt.Errorf("JWT nested payload could not be trimmed: %v", err)
			}
			return builtinJWTDecode(p)
		}
	}

	payload, err := extractJSONObject(string(p.(ast.String)))
	if err != nil {
		return nil, err
	}

	s, err := builtinBase64UrlDecode(ast.String(parts[2]))
	if err != nil {
		return nil, fmt.Errorf("JWT signature had invalid encoding: %v", err)
	}
	sign := hex.EncodeToString([]byte(s.(ast.String)))

	arr := make(ast.Array, 3)
	arr[0] = ast.NewTerm(header)
	arr[1] = ast.NewTerm(payload)
	arr[2] = ast.StringTerm(sign)
	return arr, nil
}
// builtinJWTVerifyRS256 verifies the RS256 (RSASSA-PKCS1-v1_5 with SHA-256)
// signature of a JWT in JWS compact serialization.
//
// a is the encoded token and b is a PEM-encoded X.509 certificate. The
// certificate's RSA public key is used for verification. It returns true
// when the signature over "<header>.<payload>" is valid and false when it is
// not. Malformed operands produce an error instead of a boolean: wrong
// operand types, a token without exactly 3 sections, an undecodable
// signature, a bad PEM block or certificate, or a non-RSA key.
//
// NOTE(review): the token's "alg" header is not inspected here; callers are
// presumably expected to know the token is RS256 — confirm.
func builtinJWTVerifyRS256(a ast.Value, b ast.Value) (ast.Value, error) {
	// Process the token.
	astToken, err := builtins.StringOperand(a, 1)
	if err != nil {
		return nil, err
	}
	parts := strings.Split(string(astToken), ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("encoded JWT must have 3 sections, found %d", len(parts))
	}
	// The signing input is the encoded header and payload joined by a period.
	headerPayload := []byte(strings.Join(parts[:2], "."))
	sign, err := builtinBase64UrlDecode(ast.String(parts[2]))
	if err != nil {
		return nil, fmt.Errorf("JWT signature had invalid encoding: %v", err)
	}
	signature := []byte(sign.(ast.String))

	// Process the PEM-encoded certificate input.
	astCertificate, err := builtins.StringOperand(b, 2)
	if err != nil {
		return nil, err
	}
	block, rest := pem.Decode([]byte(astCertificate))
	if block == nil || block.Type != "CERTIFICATE" || len(rest) > 0 {
		return nil, errors.New("failed to decode PEM block containing certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, errors.Wrap(err, "PEM parse error")
	}

	// A certificate may carry a non-RSA key (e.g. ECDSA). Report that as an
	// error instead of panicking on an unchecked type assertion.
	publicKey, ok := cert.PublicKey.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("certificate public key must be RSA, found %T", cert.PublicKey)
	}

	// VerifyPKCS1v15 reports a valid signature by returning a nil error.
	err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, getInputSHA(headerPayload), signature)
	return ast.Boolean(err == nil), nil
}
// validateJWTHeader parses the decoded header text h as a JSON object and
// returns it. Tokens whose header contains an "enc" parameter are JSON Web
// Encryption (JWE) objects rather than JSON Web Signatures (JWS), per
// RFC 7516 Section 9 (https://tools.ietf.org/html/rfc7516#section-9). JWE
// processing is considerably more involved and is rejected as unsupported.
func validateJWTHeader(h string) (ast.Object, error) {
	obj, parseErr := extractJSONObject(h)
	if parseErr != nil {
		return nil, fmt.Errorf("bad JWT header: %v", parseErr)
	}
	if isJWE := obj.Get(jwtEncKey) != nil; isJWE {
		return nil, errors.New("JWT is a JWE object, which is not supported")
	}
	return obj, nil
}
// extractJSONObject decodes s as JSON and returns the result. It fails
// unless the top-level value is a JSON object.
//
// XXX: If duplicate keys are present in a JWT, the last occurrence must be
// used or the token rejected. Detecting duplicates is tantamount to parsing
// the JSON ourselves. Instead, this relies on the undocumented behavior of
// Go's json.Unmarshal keeping the last occurring instance of a key (true as
// of Go 1.8.1).
func extractJSONObject(s string) (ast.Object, error) {
	decoded, err := builtinJSONUnmarshal(ast.String(s))
	if err != nil {
		return nil, fmt.Errorf("invalid JSON: %v", err)
	}
	if obj, isObj := decoded.(ast.Object); isObj {
		return obj, nil
	}
	return nil, errors.New("decoded JSON type was not an Object")
}
// getInputSHA returns the 32-byte SHA-256 digest of input.
func getInputSHA(input []byte) []byte {
	sum := sha256.Sum256(input)
	return sum[:]
}
func init() {
RegisterFunctionalBuiltin1(ast.JWTDecode.Name, builtinJWTDecode)
RegisterFunctionalBuiltin2(ast.JWTVerifyRS256.Name, builtinJWTVerifyRS256)
}