forked from slackhq/nebula
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathca_pool.go
296 lines (246 loc) · 7.9 KB
/
ca_pool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
package cert
import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
)
// CAPool is a set of trusted Nebula CA certificates together with a
// blocklist of certificate fingerprints.
type CAPool struct {
// CAs holds the trusted CA certificates. AddCA stores each entry under the
// CA's fingerprint and GetCAForCert looks entries up by a certificate's issuer.
CAs map[string]*CachedCertificate
// certBlocklist is the set of blocked certificate fingerprints. It is written
// by BlocklistFingerprint, replaced by ResetCertBlocklist and read by IsBlocklisted.
certBlocklist map[string]struct{}
}
// NewCAPool returns a pointer to a fresh CAPool whose CA map and certificate
// blocklist are both allocated and empty.
func NewCAPool() *CAPool {
	return &CAPool{
		CAs:           map[string]*CachedCertificate{},
		certBlocklist: map[string]struct{}{},
	}
}
// NewCAPoolFromPEM builds a CAPool from caPEMs, a PEM-encoded bundle of nebula
// CA certificates. Certificates are consumed one at a time with AddCAFromPEM
// until nothing but whitespace is left; AddCAFromPEM is always called at least
// once, even for empty input.
//
// An ErrExpired reported while adding a certificate does not stop parsing.
// Once the whole bundle has been read the pool is returned together with
// ErrExpired, and the caller must decide how to handle it. Any other error
// aborts parsing and a nil pool is returned with that error.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
	pool := NewCAPool()
	sawExpired := false
	remaining := caPEMs

	for {
		var err error
		remaining, err = pool.AddCAFromPEM(remaining)

		switch {
		case err == nil:
			// Certificate added cleanly.
		case errors.Is(err, ErrExpired):
			// Remember it and keep reading the rest of the bundle.
			sawExpired = true
		default:
			return nil, err
		}

		// Trailing whitespace after the last PEM block is not another certificate.
		if strings.TrimSpace(string(remaining)) == "" {
			break
		}
	}

	if sawExpired {
		return pool, ErrExpired
	}
	return pool, nil
}
// AddCAFromPEM decodes the first PEM-encoded object in pemBytes as a Nebula
// certificate and passes it to AddCA, which verifies it is a CA before adding
// it to the pool. Only that first object is consumed.
//
// The first return value is the byte slice handed back by
// UnmarshalCertificateFromPEM (the bytes that remain after the consumed
// object); it is returned whether or not an error occurred. The error is
// whatever decoding or AddCA reported, or nil on success.
func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
	c, rest, err := UnmarshalCertificateFromPEM(pemBytes)
	if err != nil {
		return rest, err
	}
	return rest, ncp.AddCA(c)
}
// AddCA verifies a Nebula CA certificate and adds it to the pool, keyed by its
// fingerprint.
//
// The certificate is rejected, and not stored, when:
//   - c.IsCA() reports false (the error wraps ErrNotCA),
//   - c is not signed by its own public key (the error wraps ErrNotSelfSigned), or
//   - its fingerprint cannot be calculated (the error wraps the underlying error).
//
// An expired CA is still stored in the pool, but an error wrapping ErrExpired
// is returned so the caller can decide how to react; test for it with
// errors.Is.
//
// The CA map is allocated on demand, so AddCA also works on a CAPool that was
// not created with NewCAPool.
func (ncp *CAPool) AddCA(c Certificate) error {
	if !c.IsCA() {
		return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
	}

	if !c.CheckSignature(c.PublicKey()) {
		return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
	}

	sum, err := c.Fingerprint()
	if err != nil {
		return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
	}

	groups := c.Groups()
	cc := &CachedCertificate{
		Certificate:    c,
		Fingerprint:    sum,
		InvertedGroups: make(map[string]struct{}, len(groups)),
	}
	for _, g := range groups {
		cc.InvertedGroups[g] = struct{}{}
	}

	// CAs is an exported field, so the pool may have been built without
	// NewCAPool; assigning into a nil map would panic.
	if ncp.CAs == nil {
		ncp.CAs = make(map[string]*CachedCertificate)
	}
	ncp.CAs[sum] = cc

	// The expiry check comes after the insert on purpose: expired CAs stay in
	// the pool and the caller is only informed.
	if c.Expired(time.Now()) {
		return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
	}

	return nil
}
// BlocklistFingerprint adds the certificate fingerprint f to the pool's
// blocklist. Adding a fingerprint that is already present is a no-op.
//
// The blocklist map is allocated on demand, so this also works on a CAPool
// that was not created with NewCAPool.
func (ncp *CAPool) BlocklistFingerprint(f string) {
	// Assigning into a nil map panics; a CAPool built as a bare struct literal
	// has no blocklist map yet.
	if ncp.certBlocklist == nil {
		ncp.certBlocklist = make(map[string]struct{})
	}
	ncp.certBlocklist[f] = struct{}{}
}
// ResetCertBlocklist forgets every previously blocklisted certificate
// fingerprint by swapping in a brand-new, empty blocklist map.
func (ncp *CAPool) ResetCertBlocklist() {
	ncp.certBlocklist = map[string]struct{}{}
}
// IsBlocklisted reports whether fingerprint is present in the pool's
// blocklist: true means the fingerprint is blocked.
func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
	_, blocked := ncp.certBlocklist[fingerprint]
	return blocked
}
// VerifyCertificate checks that c is valid at time now and was signed by a CA
// trusted by this pool.
//
// On success it returns a CachedCertificate holding c, its fingerprint, the
// fingerprint of the CA that signed it and a set built from c.Groups(). That
// value can be handed to VerifyCachedCertificate for cheaper re-verification.
//
// An error is returned when c is nil, when its fingerprint cannot be
// calculated, or when verification fails.
func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
	if c == nil {
		return nil, fmt.Errorf("no certificate")
	}

	fingerprint, err := c.Fingerprint()
	if err != nil {
		return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
	}

	// An empty signer fingerprint asks verify for the full, uncached check.
	ca, err := ncp.verify(c, now, fingerprint, "")
	if err != nil {
		return nil, err
	}
	signerFingerprint := ca.Fingerprint

	groupSet := make(map[string]struct{})
	for _, group := range c.Groups() {
		groupSet[group] = struct{}{}
	}

	return &CachedCertificate{
		Certificate:       c,
		InvertedGroups:    groupSet,
		Fingerprint:       fingerprint,
		signerFingerprint: signerFingerprint,
	}, nil
}
// VerifyCachedCertificate re-verifies a certificate that VerifyCertificate
// already accepted. It is cheaper than VerifyCertificate: when c carries the
// fingerprint of the CA that signed it, verify compares that fingerprint with
// the CA currently in the pool and skips the signature and CA-constraint
// checks. Blocklist and expiry checks still run against now.
//
// It returns an error when c or c.Certificate is nil (mirroring
// VerifyCertificate, instead of panicking), or when verification fails.
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
	if c == nil || c.Certificate == nil {
		return fmt.Errorf("no certificate")
	}

	_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
	return err
}
// verify is the shared implementation behind VerifyCertificate and
// VerifyCachedCertificate. certFp is the fingerprint of c. signerFp is the
// fingerprint of the CA recorded by an earlier successful verification, or ""
// when c has not been verified before.
//
// Checks run in this order, and the first failure is returned:
//  1. certFp is on the blocklist -> ErrBlockListed
//  2. no CA in the pool for c    -> the error from GetCAForCert
//  3. the CA is expired at now   -> ErrRootExpired
//  4. c is expired at now        -> ErrExpired
//  5. cached path (signerFp set): the pool's CA fingerprint must equal
//     signerFp, otherwise ErrFingerprintMismatch; nothing else is checked
//  6. uncached path: signature   -> ErrSignatureMismatch
//  7. uncached path: constraints -> the error from CheckCAConstraints
//
// On success the signing CA from the pool is returned.
func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
	if ncp.IsBlocklisted(certFp) {
		return nil, ErrBlockListed
	}

	ca, err := ncp.GetCAForCert(c)
	if err != nil {
		return nil, err
	}

	switch {
	case ca.Certificate.Expired(now):
		return nil, ErrRootExpired
	case c.Expired(now):
		return nil, ErrExpired
	}

	// A cached certificate already passed the signature and constraint checks
	// against the CA named by signerFp, so it is enough to confirm the pool
	// still maps it to that same CA.
	if signerFp != "" {
		if signerFp == ca.Fingerprint {
			return ca, nil
		}
		return nil, ErrFingerprintMismatch
	}

	if !c.CheckSignature(ca.Certificate.PublicKey()) {
		return nil, ErrSignatureMismatch
	}

	if err := CheckCAConstraints(ca.Certificate, c); err != nil {
		return nil, err
	}

	return ca, nil
}
// GetCAForCert looks up the CA in the pool that is registered under the
// issuer named by c. It is a pure lookup: no signature validation is
// performed. An error is returned when c has no issuer or when the pool has
// no entry for that issuer.
func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
	issuer := c.Issuer()
	if issuer == "" {
		return nil, fmt.Errorf("no issuer in certificate")
	}

	ca, known := ncp.CAs[issuer]
	if !known {
		return nil, fmt.Errorf("could not find ca for the certificate")
	}
	return ca, nil
}
// GetFingerprints returns the fingerprints of every CA in the pool, i.e. the
// keys of the CA map. The result is a new, non-nil slice; its order is not
// defined because it follows map iteration order.
func (ncp *CAPool) GetFingerprints() []string {
	fingerprints := make([]string, 0, len(ncp.CAs))
	for fingerprint := range ncp.CAs {
		fingerprints = append(fingerprints, fingerprint)
	}
	return fingerprints
}
// CheckCAConstraints reports an error when the certificate sub violates a
// constraint carried by the signer certificate. It reads sub's validity
// window, groups, networks and unsafe networks and passes them to
// checkCAConstraints.
func CheckCAConstraints(signer Certificate, sub Certificate) error {
	var (
		notBefore      = sub.NotBefore()
		notAfter       = sub.NotAfter()
		groups         = sub.Groups()
		networks       = sub.Networks()
		unsafeNetworks = sub.UnsafeNetworks()
	)
	return checkCAConstraints(signer, notBefore, notAfter, groups, networks, unsafeNetworks)
}
// checkCAConstraints holds the constraint logic in a form that takes the
// sub-certificate's fields individually, so both Certificates and
// TBSCertificates can be tested against a signer.
//
// It returns an error for the first violated rule, checked in this order:
//  1. notAfter is later than the signer's NotAfter.
//  2. notBefore is earlier than the signer's NotBefore.
//  3. the signer lists groups and one of groups is not among them.
//  4. the signer lists networks and one of networks is not inside any of them.
//  5. the signer lists unsafe networks and one of unsafeNetworks is not inside
//     any of them.
//
// An empty group/network list on the signer means that attribute is not
// restricted. It returns nil when every rule holds.
func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
	// within reports whether candidate lies inside at least one allowed prefix:
	// the allowed prefix must contain candidate's address and must not be
	// narrower (have more bits) than candidate.
	within := func(allowed []netip.Prefix, candidate netip.Prefix) bool {
		return slices.ContainsFunc(allowed, func(limit netip.Prefix) bool {
			return limit.Contains(candidate.Addr()) && limit.Bits() <= candidate.Bits()
		})
	}

	// The certificate's validity window must sit inside the signer's.
	if notAfter.After(signer.NotAfter()) {
		return fmt.Errorf("certificate expires after signing certificate")
	}
	if notBefore.Before(signer.NotBefore()) {
		return fmt.Errorf("certificate is valid before the signing certificate")
	}

	// Groups: when the signer restricts them, the certificate may only use a subset.
	if allowed := signer.Groups(); len(allowed) > 0 {
		for _, group := range groups {
			if !slices.Contains(allowed, group) {
				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", group)
			}
		}
	}

	// Networks: every assignment must fall inside one of the signer's ranges.
	if allowed := signer.Networks(); len(allowed) > 0 {
		for _, network := range networks {
			if !within(allowed, network) {
				return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", network.String())
			}
		}
	}

	// Unsafe networks: same containment rule against the signer's unsafe ranges.
	if allowed := signer.UnsafeNetworks(); len(allowed) > 0 {
		for _, network := range unsafeNetworks {
			if !within(allowed, network) {
				return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", network.String())
			}
		}
	}

	return nil
}