package main import ( "bytes" "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" "log" "net/http" "os" "os/signal" "path/filepath" "sort" "strings" "sync" "syscall" "github.com/alecthomas/kingpin/v2" cfocsp "github.com/cloudflare/cfssl/ocsp" "github.com/prometheus/client_golang/prometheus/promhttp" "ocspcrl/internal/metrics" "ocspcrl/internal/ocsp_source" ) const ( crlFileName = "crl.pem" keyFileName = "key.pem" responderFileName = "responder.crt" caFileName = "ca.crt" ) type configuration struct { casDirectory string applicationListenAddress string metricsListenAddress string } func parseConfiguration(args []string) *configuration { config := &configuration{} app := kingpin.New("ocspcrl", "OCSP responder / CRL server (multi-CA)") app.HelpFlag.Short('h') app.Flag("cas-directory", "Path to a directory containing one subdirectory per CA. Each subdirectory must contain ca.crt, responder.crt, key.pem and crl.pem. The subdirectory name is used as the route prefix."). Envar("CAS_DIRECTORY").Required().ExistingDirVar(&config.casDirectory) app.Flag("web.listen-address", "Address for application endpoint"). Envar("WEB_LISTEN_ADDRESS").Default(":8080").StringVar(&config.applicationListenAddress) app.Flag("metrics.listen-address", "Address for metrics endpoint"). Envar("METRICS_LISTEN_ADDRESS").Default("[::1]:8081").StringVar(&config.metricsListenAddress) kingpin.MustParse(app.Parse(args)) return config } func decodeCrlBytes(content []byte) ([]byte, error) { if !bytes.Contains(content, []byte("BEGIN")) { return content, nil } block, rest := pem.Decode(content) if block == nil { return nil, fmt.Errorf("crl pem block could not be decoded") } if len(bytes.TrimSpace(rest)) > 0 { return nil, fmt.Errorf("crl file contains trailing data") } return block.Bytes, nil } func loadCrlFromFile(path string) (*x509.RevocationList, error) { content, readError := os.ReadFile(path) if readError != nil { return nil, readError } derBytes, decodeError := decodeCrlBytes(content) if decodeError != nil { return nil, fmt.Errorf("%s: %w", path, decodeError) } return x509.ParseRevocationList(derBytes) } func loadCertificateFromFile(path string) (*x509.Certificate, error) { content, readError := os.ReadFile(path) if readError != nil { return nil, readError } block, rest := pem.Decode(content) if block == nil { return nil, fmt.Errorf("%s: certificate pem block could not be decoded", path) } if len(bytes.TrimSpace(rest)) > 0 { return nil, fmt.Errorf("%s: certificate file contains trailing data", path) } return x509.ParseCertificate(block.Bytes) } type caInstance struct { name string crlPath string caCertificate *x509.Certificate source *ocsp_source.CrlSource crlMutex sync.RWMutex crl *x509.RevocationList } func (c *caInstance) reloadCrl() error { crl, loadError := loadCrlFromFile(c.crlPath) if loadError != nil { return loadError } c.crlMutex.Lock() c.crl = crl c.crlMutex.Unlock() metrics.CrlEntries.WithLabelValues(c.name).Set(float64(len(crl.RevokedCertificateEntries))) c.source.UseCrl(*crl) return nil } func (c *caInstance) currentCrl() *x509.RevocationList { c.crlMutex.RLock() defer c.crlMutex.RUnlock() return c.crl } type caFiles struct { key string responder string ca string crl string } func newCaFiles(directory string) caFiles { return caFiles{ key: filepath.Join(directory, keyFileName), responder: filepath.Join(directory, responderFileName), ca: filepath.Join(directory, caFileName), crl: filepath.Join(directory, crlFileName), } } func (f caFiles) ensureExist(caName string) error { for _, path := range []string{f.key, f.responder, f.ca, f.crl} { if _, statError := os.Stat(path); statError != nil { return fmt.Errorf("ca %q: %w", caName, statError) } } return nil } func verifyResponderIssuedByCa(responder tls.Certificate, ca *x509.Certificate) error { if responder.Leaf == nil { return fmt.Errorf("responder leaf certificate could not be parsed") } if !bytes.Equal(ca.RawSubject, responder.Leaf.RawIssuer) { return fmt.Errorf("responder certificate issuer does not match ca certificate subject; %+q != %+q", ca.Subject.String(), responder.Leaf.Issuer.String()) } return nil } func loadCa(name, directory string) (*caInstance, error) { files := newCaFiles(directory) if existsError := files.ensureExist(name); existsError != nil { return nil, existsError } responderKeyPair, loadResponderError := tls.LoadX509KeyPair(files.responder, files.key) if loadResponderError != nil { return nil, fmt.Errorf("ca %q: failed to load responder key pair: %w", name, loadResponderError) } caCertificate, loadCaError := loadCertificateFromFile(files.ca) if loadCaError != nil { return nil, fmt.Errorf("ca %q: failed to load ca certificate: %w", name, loadCaError) } if verifyError := verifyResponderIssuedByCa(responderKeyPair, caCertificate); verifyError != nil { return nil, fmt.Errorf("ca %q: %w", name, verifyError) } instance := &caInstance{ name: name, crlPath: files.crl, caCertificate: caCertificate, source: ocsp_source.NewCrlSource(caCertificate, responderKeyPair), } if reloadError := instance.reloadCrl(); reloadError != nil { return nil, fmt.Errorf("ca %q: failed to load crl: %w", name, reloadError) } return instance, nil } func listCaSubdirectories(rootDir string) ([]string, error) { entries, readDirError := os.ReadDir(rootDir) if readDirError != nil { return nil, fmt.Errorf("failed to read cas directory: %w", readDirError) } names := []string{} for _, entry := range entries { if !entry.IsDir() { continue } if strings.HasPrefix(entry.Name(), ".") { continue } names = append(names, entry.Name()) } sort.Strings(names) return names, nil } func discoverCas(rootDir string) ([]*caInstance, error) { names, listError := listCaSubdirectories(rootDir) if listError != nil { return nil, listError } if len(names) == 0 { return nil, fmt.Errorf("no ca subdirectories found in %s", rootDir) } cas := make([]*caInstance, 0, len(names)) for _, name := range names { instance, loadError := loadCa(name, filepath.Join(rootDir, name)) if loadError != nil { return nil, loadError } cas = append(cas, instance) } return cas, nil } func writeBinary(w http.ResponseWriter, contentType string, body []byte) { w.Header().Set("Content-Type", contentType) w.Write(body) } func writePem(w http.ResponseWriter, contentType, blockType string, body []byte) { w.Header().Set("Content-Type", contentType) pem.Encode(w, &pem.Block{Type: blockType, Bytes: body}) } func registerOcspRoutes(router *http.ServeMux, prefix string, ca *caInstance) { responder := cfocsp.NewResponder(ca.source, nil) router.Handle(prefix+"/ocsp", responder) router.Handle(prefix+"/ocsp/", http.StripPrefix(prefix+"/ocsp/", responder)) } func registerCrlRoutes(router *http.ServeMux, prefix string, ca *caInstance) { router.HandleFunc(prefix+"/crl", func(w http.ResponseWriter, r *http.Request) { writeBinary(w, "application/pkix-cert", ca.currentCrl().Raw) }) router.HandleFunc(prefix+"/crl.pem", func(w http.ResponseWriter, r *http.Request) { writePem(w, "application/pkix-crl", "X509 CRL", ca.currentCrl().Raw) }) } func registerCaCertificateRoutes(router *http.ServeMux, prefix string, ca *caInstance) { router.HandleFunc(prefix+"/ca", func(w http.ResponseWriter, r *http.Request) { writeBinary(w, "application/pkix-cert", ca.caCertificate.Raw) }) router.HandleFunc(prefix+"/ca.pem", func(w http.ResponseWriter, r *http.Request) { writePem(w, "application/x-x509-ca-cert", "CERTIFICATE", ca.caCertificate.Raw) }) } func registerCaRoutes(router *http.ServeMux, ca *caInstance) { prefix := "/" + ca.name registerOcspRoutes(router, prefix, ca) registerCrlRoutes(router, prefix, ca) registerCaCertificateRoutes(router, prefix, ca) } func buildApplicationRouter(cas []*caInstance) *http.ServeMux { router := http.NewServeMux() for _, ca := range cas { registerCaRoutes(router, ca) log.Printf("registered ca %q with routes under /%s/", ca.name, ca.name) } return router } func reloadAllCrls(cas []*caInstance) { for _, ca := range cas { if reloadError := ca.reloadCrl(); reloadError != nil { log.Printf("failed to reload crl for ca %q: %v", ca.name, reloadError) } else { log.Printf("reloaded crl for ca %q", ca.name) } } } func runReloadWorker(signalChan <-chan os.Signal, cas []*caInstance) { defer log.Println("reload crl worker stopped") for { _, ok := <-signalChan if !ok { return } reloadAllCrls(cas) } } func startServer(server *http.Server, label string) <-chan struct{} { closed := make(chan struct{}) go func() { log.Printf("starting %s server on %+q", label, server.Addr) if listenError := server.ListenAndServe(); !errors.Is(listenError, http.ErrServerClosed) { log.Printf("%s server error: %v", label, listenError) } close(closed) }() return closed } func main() { log.SetFlags(log.Lshortfile) config := parseConfiguration(os.Args[1:]) cas, discoverError := discoverCas(config.casDirectory) if discoverError != nil { log.Fatalf("failed to load cas: %v", discoverError) } applicationRouter := buildApplicationRouter(cas) terminationChan := make(chan os.Signal, 1) signal.Notify(terminationChan, syscall.SIGINT, syscall.SIGTERM) hupChan := make(chan os.Signal, 1) signal.Notify(hupChan, syscall.SIGHUP) go runReloadWorker(hupChan, cas) applicationServer := &http.Server{Addr: config.applicationListenAddress, Handler: metrics.Middleware(applicationRouter)} metricsServer := &http.Server{Addr: config.metricsListenAddress, Handler: promhttp.Handler()} applicationServerClosed := startServer(applicationServer, "application") metricsServerClosed := startServer(metricsServer, "metrics") <-terminationChan close(hupChan) applicationServer.Shutdown(nil) metricsServer.Shutdown(nil) <-applicationServerClosed <-metricsServerClosed }