Skip to content

Commit

Permalink
azure: add support for Azure Private Zone DNS (go-acme#1561)
Browse files Browse the repository at this point in the history
Co-authored-by: Fernandez Ludovic <[email protected]>
  • Loading branch information
knutejoh and ldez authored Jan 17, 2022
1 parent 27bb3f2 commit 6dd5d1f
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 111 deletions.
1 change: 1 addition & 0 deletions cmd/zz_gen_cmd_dnshelp.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ func displayDNSHelp(name string) error {
ew.writeln(`Additional Configuration:`)
ew.writeln(` - "AZURE_METADATA_ENDPOINT": Metadata Service endpoint URL`)
ew.writeln(` - "AZURE_POLLING_INTERVAL": Time between DNS propagation check`)
ew.writeln(` - "AZURE_PRIVATE_ZONE": Set to true to use Azure Private DNS Zones and not public`)
ew.writeln(` - "AZURE_PROPAGATION_TIMEOUT": Maximum waiting time for DNS propagation`)
ew.writeln(` - "AZURE_TTL": The TTL of the TXT record used for the DNS challenge`)
ew.writeln(` - "AZURE_ZONE_NAME": Zone name to use inside Azure DNS service to add the TXT record in`)
Expand Down
1 change: 1 addition & 0 deletions docs/content/dns/zz_gen_azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ More information [here](/lego/dns/#configuration-and-credentials).
|--------------------------------|-------------|
| `AZURE_METADATA_ENDPOINT` | Metadata Service endpoint URL |
| `AZURE_POLLING_INTERVAL` | Time between DNS propagation check |
| `AZURE_PRIVATE_ZONE` | Set to true to use Azure Private DNS Zones and not public |
| `AZURE_PROPAGATION_TIMEOUT` | Maximum waiting time for DNS propagation |
| `AZURE_TTL` | The TTL of the TXT record used for the DNS challenge |
| `AZURE_ZONE_NAME` | Zone name to use inside Azure DNS service to add the TXT record in |
Expand Down
111 changes: 13 additions & 98 deletions providers/dns/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
package azure

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2017-09-01/dns"
"github.com/Azure/go-autorest/autorest"
aazure "github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/Azure/go-autorest/autorest/to"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env"
)
Expand All @@ -34,6 +32,7 @@ const (
EnvClientID = envNamespace + "CLIENT_ID"
EnvClientSecret = envNamespace + "CLIENT_SECRET"
EnvZoneName = envNamespace + "ZONE_NAME"
EnvPrivateZone = envNamespace + "PRIVATE_ZONE"

EnvTTL = envNamespace + "TTL"
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
Expand All @@ -49,6 +48,7 @@ type Config struct {

SubscriptionID string
ResourceGroup string
PrivateZone bool

MetadataEndpoint string
ResourceManagerEndpoint string
Expand All @@ -74,8 +74,7 @@ func NewDefaultConfig() *Config {

// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
config *Config
authorizer autorest.Authorizer
provider challenge.ProviderTimeout
}

// NewDNSProvider returns a DNSProvider instance configured for azure.
Expand Down Expand Up @@ -113,6 +112,7 @@ func NewDNSProvider() (*DNSProvider, error) {
config.ClientSecret = env.GetOrFile(EnvClientSecret)
config.ClientID = env.GetOrFile(EnvClientID)
config.TenantID = env.GetOrFile(EnvTenantID)
config.PrivateZone = env.GetOrDefaultBool(EnvPrivateZone, false)

return NewDNSProviderConfig(config)
}
Expand Down Expand Up @@ -156,112 +156,27 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
config.ResourceGroup = resGroup
}

return &DNSProvider{config: config, authorizer: authorizer}, nil
if config.PrivateZone {
return &DNSProvider{provider: &dnsProviderPrivate{config: config, authorizer: authorizer}}, nil
}

return &DNSProvider{provider: &dnsProviderPublic{config: config, authorizer: authorizer}}, nil
}

// Timeout returns the timeout and interval to use when checking for DNS propagation.
// Adjusting here to cope with spikes in propagation times.
func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
return d.config.PropagationTimeout, d.config.PollingInterval
return d.provider.Timeout()
}

// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
ctx := context.Background()
fqdn, value := dns01.GetRecord(domain, keyAuth)

zone, err := d.getHostedZoneID(ctx, fqdn)
if err != nil {
return fmt.Errorf("azure: %w", err)
}

rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
rsc.Authorizer = d.authorizer

relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone))

// Get existing record set
rset, err := rsc.Get(ctx, d.config.ResourceGroup, zone, relative, dns.TXT)
if err != nil {
var detailed autorest.DetailedError
if !errors.As(err, &detailed) || detailed.StatusCode != http.StatusNotFound {
return fmt.Errorf("azure: %w", err)
}
}

// Construct unique TXT records using map
uniqRecords := map[string]struct{}{value: {}}
if rset.RecordSetProperties != nil && rset.TxtRecords != nil {
for _, txtRecord := range *rset.TxtRecords {
// Assume Value doesn't contain multiple strings
if txtRecord.Value != nil && len(*txtRecord.Value) > 0 {
uniqRecords[(*txtRecord.Value)[0]] = struct{}{}
}
}
}

var txtRecords []dns.TxtRecord
for txt := range uniqRecords {
txtRecords = append(txtRecords, dns.TxtRecord{Value: &[]string{txt}})
}

rec := dns.RecordSet{
Name: &relative,
RecordSetProperties: &dns.RecordSetProperties{
TTL: to.Int64Ptr(int64(d.config.TTL)),
TxtRecords: &txtRecords,
},
}

_, err = rsc.CreateOrUpdate(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, rec, "", "")
if err != nil {
return fmt.Errorf("azure: %w", err)
}
return nil
return d.provider.Present(domain, token, keyAuth)
}

// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
ctx := context.Background()
fqdn, _ := dns01.GetRecord(domain, keyAuth)

zone, err := d.getHostedZoneID(ctx, fqdn)
if err != nil {
return fmt.Errorf("azure: %w", err)
}

relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone))
rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
rsc.Authorizer = d.authorizer

_, err = rsc.Delete(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, "")
if err != nil {
return fmt.Errorf("azure: %w", err)
}
return nil
}

// Checks that azure has a zone for this domain name.
func (d *DNSProvider) getHostedZoneID(ctx context.Context, fqdn string) (string, error) {
if zone := env.GetOrFile(EnvZoneName); zone != "" {
return zone, nil
}

authZone, err := dns01.FindZoneByFqdn(fqdn)
if err != nil {
return "", err
}

dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID)
dc.Authorizer = d.authorizer

zone, err := dc.Get(ctx, d.config.ResourceGroup, dns01.UnFqdn(authZone))
if err != nil {
return "", err
}

// zone.Name shouldn't have a trailing dot(.)
return to.String(zone.Name), nil
return d.provider.CleanUp(domain, token, keyAuth)
}

// Returns the relative record to the domain.
Expand Down
5 changes: 3 additions & 2 deletions providers/dns/azure/azure.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ Example = ''''''
AZURE_RESOURCE_GROUP = "Resource group"
'instance metadata service' = "If the credentials are **not** set via the environment, then it will attempt to get a bearer token via the [instance metadata service](https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service)."
[Configuration.Additional]
AZURE_METADATA_ENDPOINT = "Metadata Service endpoint URL"
AZURE_PRIVATE_ZONE = "Set to true to use Azure Private DNS Zones and not public"
AZURE_ZONE_NAME = "Zone name to use inside Azure DNS service to add the TXT record in"
AZURE_POLLING_INTERVAL = "Time between DNS propagation check"
AZURE_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation"
AZURE_TTL = "The TTL of the TXT record used for the DNS challenge"
AZURE_METADATA_ENDPOINT = "Metadata Service endpoint URL"
AZURE_ZONE_NAME = "Zone name to use inside Azure DNS service to add the TXT record in"

[Links]
API = "https://docs.microsoft.com/en-us/go/azure/"
Expand Down
45 changes: 34 additions & 11 deletions providers/dns/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/go-acme/lego/v4/platform/tester"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -59,13 +60,16 @@ func TestNewDNSProvider(t *testing.T) {

p, err := NewDNSProvider()

if test.expected == "" {
require.NoError(t, err)
require.NotNil(t, p)
require.NotNil(t, p.config)
} else {
if test.expected != "" {
require.EqualError(t, err, test.expected)
return
}

require.NoError(t, err)
require.NotNil(t, p)
require.NotNil(t, p.provider)

assert.IsType(t, p.provider, new(dnsProviderPublic))
})
}
}
Expand All @@ -78,16 +82,27 @@ func TestNewDNSProviderConfig(t *testing.T) {
subscriptionID string
tenantID string
resourceGroup string
privateZone bool
handler func(w http.ResponseWriter, r *http.Request)
expected string
}{
{
desc: "success",
desc: "success (public)",
clientID: "A",
clientSecret: "B",
tenantID: "C",
subscriptionID: "D",
resourceGroup: "E",
privateZone: false,
},
{
desc: "success (private)",
clientID: "A",
clientSecret: "B",
tenantID: "C",
subscriptionID: "D",
resourceGroup: "E",
privateZone: true,
},
{
desc: "SubscriptionID missing",
Expand Down Expand Up @@ -132,6 +147,7 @@ func TestNewDNSProviderConfig(t *testing.T) {
config.SubscriptionID = test.subscriptionID
config.TenantID = test.tenantID
config.ResourceGroup = test.resourceGroup
config.PrivateZone = test.privateZone

mux := http.NewServeMux()
server := httptest.NewServer(mux)
Expand All @@ -146,12 +162,19 @@ func TestNewDNSProviderConfig(t *testing.T) {

p, err := NewDNSProviderConfig(config)

if test.expected == "" {
require.NoError(t, err)
require.NotNil(t, p)
require.NotNil(t, p.config)
} else {
if test.expected != "" {
require.EqualError(t, err, test.expected)
return
}

require.NoError(t, err)
require.NotNil(t, p)
require.NotNil(t, p.provider)

if test.privateZone {
assert.IsType(t, p.provider, new(dnsProviderPrivate))
} else {
assert.IsType(t, p.provider, new(dnsProviderPublic))
}
})
}
Expand Down
Loading

0 comments on commit 6dd5d1f

Please sign in to comment.