refactor: Split large main file

This commit is contained in:
T. R. Bernstein
2026-04-30 02:09:16 +02:00
parent 4e4dd2eaae
commit 853e2a909f
7 changed files with 332 additions and 291 deletions

294
main.go
View File

@@ -1,307 +1,19 @@
package main
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"context"
"errors"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"sort"
"strings"
"sync"
"syscall"
"github.com/alecthomas/kingpin/v2"
cfocsp "github.com/cloudflare/cfssl/ocsp"
"github.com/prometheus/client_golang/prometheus/promhttp"
"ocspcrl/internal/metrics"
"ocspcrl/internal/ocsp_source"
)
const (
crlFileName = "crl.pem"
keyFileName = "key.pem"
responderFileName = "responder.crt"
caFileName = "ca.crt"
)
type configuration struct {
casDirectory string
applicationListenAddress string
metricsListenAddress string
}
func parseConfiguration(args []string) *configuration {
config := &configuration{}
app := kingpin.New("ocspcrl", "OCSP responder / CRL server (multi-CA)")
app.HelpFlag.Short('h')
app.Flag("cas-directory", "Path to a directory containing one subdirectory per CA. Each subdirectory must contain ca.crt, responder.crt, key.pem and crl.pem. The subdirectory name is used as the route prefix.").
Envar("CAS_DIRECTORY").Required().ExistingDirVar(&config.casDirectory)
app.Flag("web.listen-address", "Address for application endpoint").
Envar("WEB_LISTEN_ADDRESS").Default(":8080").StringVar(&config.applicationListenAddress)
app.Flag("metrics.listen-address", "Address for metrics endpoint").
Envar("METRICS_LISTEN_ADDRESS").Default("[::1]:8081").StringVar(&config.metricsListenAddress)
kingpin.MustParse(app.Parse(args))
return config
}
func decodeCrlBytes(content []byte) ([]byte, error) {
if !bytes.Contains(content, []byte("BEGIN")) {
return content, nil
}
block, rest := pem.Decode(content)
if block == nil {
return nil, fmt.Errorf("crl pem block could not be decoded")
}
if len(bytes.TrimSpace(rest)) > 0 {
return nil, fmt.Errorf("crl file contains trailing data")
}
return block.Bytes, nil
}
func loadCrlFromFile(path string) (*x509.RevocationList, error) {
content, readError := os.ReadFile(path)
if readError != nil {
return nil, readError
}
derBytes, decodeError := decodeCrlBytes(content)
if decodeError != nil {
return nil, fmt.Errorf("%s: %w", path, decodeError)
}
return x509.ParseRevocationList(derBytes)
}
func loadCertificateFromFile(path string) (*x509.Certificate, error) {
content, readError := os.ReadFile(path)
if readError != nil {
return nil, readError
}
block, rest := pem.Decode(content)
if block == nil {
return nil, fmt.Errorf("%s: certificate pem block could not be decoded", path)
}
if len(bytes.TrimSpace(rest)) > 0 {
return nil, fmt.Errorf("%s: certificate file contains trailing data", path)
}
return x509.ParseCertificate(block.Bytes)
}
type caInstance struct {
name string
crlPath string
caCertificate *x509.Certificate
source *ocsp_source.CrlSource
crlMutex sync.RWMutex
crl *x509.RevocationList
}
func (c *caInstance) reloadCrl() error {
crl, loadError := loadCrlFromFile(c.crlPath)
if loadError != nil {
return loadError
}
c.crlMutex.Lock()
c.crl = crl
c.crlMutex.Unlock()
metrics.CrlEntries.WithLabelValues(c.name).Set(float64(len(crl.RevokedCertificateEntries)))
c.source.UseCrl(*crl)
return nil
}
func (c *caInstance) currentCrl() *x509.RevocationList {
c.crlMutex.RLock()
defer c.crlMutex.RUnlock()
return c.crl
}
type caFiles struct {
key string
responder string
ca string
crl string
}
func newCaFiles(directory string) caFiles {
return caFiles{
key: filepath.Join(directory, keyFileName),
responder: filepath.Join(directory, responderFileName),
ca: filepath.Join(directory, caFileName),
crl: filepath.Join(directory, crlFileName),
}
}
func (f caFiles) ensureExist(caName string) error {
for _, path := range []string{f.key, f.responder, f.ca, f.crl} {
if _, statError := os.Stat(path); statError != nil {
return fmt.Errorf("ca %q: %w", caName, statError)
}
}
return nil
}
func verifyResponderIssuedByCa(responder tls.Certificate, ca *x509.Certificate) error {
if responder.Leaf == nil {
return fmt.Errorf("responder leaf certificate could not be parsed")
}
if !bytes.Equal(ca.RawSubject, responder.Leaf.RawIssuer) {
return fmt.Errorf("responder certificate issuer does not match ca certificate subject; %+q != %+q",
ca.Subject.String(), responder.Leaf.Issuer.String())
}
return nil
}
func loadCa(name, directory string) (*caInstance, error) {
files := newCaFiles(directory)
if existsError := files.ensureExist(name); existsError != nil {
return nil, existsError
}
responderKeyPair, loadResponderError := tls.LoadX509KeyPair(files.responder, files.key)
if loadResponderError != nil {
return nil, fmt.Errorf("ca %q: failed to load responder key pair: %w", name, loadResponderError)
}
caCertificate, loadCaError := loadCertificateFromFile(files.ca)
if loadCaError != nil {
return nil, fmt.Errorf("ca %q: failed to load ca certificate: %w", name, loadCaError)
}
if verifyError := verifyResponderIssuedByCa(responderKeyPair, caCertificate); verifyError != nil {
return nil, fmt.Errorf("ca %q: %w", name, verifyError)
}
instance := &caInstance{
name: name,
crlPath: files.crl,
caCertificate: caCertificate,
source: ocsp_source.NewCrlSource(caCertificate, responderKeyPair),
}
if reloadError := instance.reloadCrl(); reloadError != nil {
return nil, fmt.Errorf("ca %q: failed to load crl: %w", name, reloadError)
}
return instance, nil
}
func listCaSubdirectories(rootDir string) ([]string, error) {
entries, readDirError := os.ReadDir(rootDir)
if readDirError != nil {
return nil, fmt.Errorf("failed to read cas directory: %w", readDirError)
}
names := []string{}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
if strings.HasPrefix(entry.Name(), ".") {
continue
}
names = append(names, entry.Name())
}
sort.Strings(names)
return names, nil
}
func discoverCas(rootDir string) ([]*caInstance, error) {
names, listError := listCaSubdirectories(rootDir)
if listError != nil {
return nil, listError
}
if len(names) == 0 {
return nil, fmt.Errorf("no ca subdirectories found in %s", rootDir)
}
cas := make([]*caInstance, 0, len(names))
for _, name := range names {
instance, loadError := loadCa(name, filepath.Join(rootDir, name))
if loadError != nil {
return nil, loadError
}
cas = append(cas, instance)
}
return cas, nil
}
func writeBinary(w http.ResponseWriter, contentType string, body []byte) {
w.Header().Set("Content-Type", contentType)
w.Write(body)
}
func writePem(w http.ResponseWriter, contentType, blockType string, body []byte) {
w.Header().Set("Content-Type", contentType)
pem.Encode(w, &pem.Block{Type: blockType, Bytes: body})
}
func registerOcspRoutes(router *http.ServeMux, prefix string, ca *caInstance) {
responder := cfocsp.NewResponder(ca.source, nil)
router.Handle(prefix+"/ocsp", responder)
router.Handle(prefix+"/ocsp/", http.StripPrefix(prefix+"/ocsp/", responder))
}
func registerCrlRoutes(router *http.ServeMux, prefix string, ca *caInstance) {
router.HandleFunc(prefix+"/crl", func(w http.ResponseWriter, r *http.Request) {
writeBinary(w, "application/pkix-cert", ca.currentCrl().Raw)
})
router.HandleFunc(prefix+"/crl.pem", func(w http.ResponseWriter, r *http.Request) {
writePem(w, "application/pkix-crl", "X509 CRL", ca.currentCrl().Raw)
})
}
func registerCaCertificateRoutes(router *http.ServeMux, prefix string, ca *caInstance) {
router.HandleFunc(prefix+"/ca", func(w http.ResponseWriter, r *http.Request) {
writeBinary(w, "application/pkix-cert", ca.caCertificate.Raw)
})
router.HandleFunc(prefix+"/ca.pem", func(w http.ResponseWriter, r *http.Request) {
writePem(w, "application/x-x509-ca-cert", "CERTIFICATE", ca.caCertificate.Raw)
})
}
func registerCaRoutes(router *http.ServeMux, ca *caInstance) {
prefix := "/" + ca.name
registerOcspRoutes(router, prefix, ca)
registerCrlRoutes(router, prefix, ca)
registerCaCertificateRoutes(router, prefix, ca)
}
func buildApplicationRouter(cas []*caInstance) *http.ServeMux {
router := http.NewServeMux()
for _, ca := range cas {
registerCaRoutes(router, ca)
log.Printf("registered ca %q with routes under /%s/", ca.name, ca.name)
}
return router
}
func reloadAllCrls(cas []*caInstance) {
for _, ca := range cas {
if reloadError := ca.reloadCrl(); reloadError != nil {
log.Printf("failed to reload crl for ca %q: %v", ca.name, reloadError)
} else {
log.Printf("reloaded crl for ca %q", ca.name)
}
}
}
func runReloadWorker(signalChan <-chan os.Signal, cas []*caInstance) {
defer log.Println("reload crl worker stopped")
for {
_, ok := <-signalChan
if !ok {
return
}
reloadAllCrls(cas)
}
}
func startServer(server *http.Server, label string) <-chan struct{} {
closed := make(chan struct{})
go func() {
@@ -340,8 +52,8 @@ func main() {
<-terminationChan
close(hupChan)
applicationServer.Shutdown(nil)
metricsServer.Shutdown(nil)
applicationServer.Shutdown(context.TODO())
metricsServer.Shutdown(context.TODO())
<-applicationServerClosed
<-metricsServerClosed
}