// grimoire/go/linksmith.go
// Listing metadata: 394 lines, 9.6 KiB, Go.

package main
import (
"crypto/x509"
_ "embed"
"encoding/base64"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"regexp"
"runtime"
"strings"
)
// pemCertHeader and pemCertFooter delimit a PEM-encoded X.509 certificate.
const (
	pemCertHeader = "-----BEGIN CERTIFICATE-----"
	pemCertFooter = "-----END CERTIFICATE-----"
)

// Certificate wraps a parsed X.509 certificate. Store is nil for the zero
// value and for the Certificate returned alongside an error; every method
// below dereferences Store and therefore expects it to be non-nil.
type Certificate struct {
	Store *x509.Certificate
}

// LoadCertFromDER parses a single DER-encoded certificate.
func LoadCertFromDER(der []byte) (Certificate, error) {
	store, err := x509.ParseCertificate(der)
	if err != nil {
		return Certificate{}, err
	}
	return Certificate{Store: store}, nil
}

// LoadCertFromPEM parses a single PEM-encoded certificate.
//
// Leading and trailing whitespace (such as the final newline of a file or
// HTTP body) is ignored; otherwise the text must start with the BEGIN
// CERTIFICATE line and end with the END CERTIFICATE line. Both LF and CRLF
// line endings are accepted.
func LoadCertFromPEM(pem string) (Certificate, error) {
	pem = strings.TrimSpace(pem)
	if !strings.HasPrefix(pem, pemCertHeader) {
		return Certificate{}, errors.New("invalid certificate header")
	}
	if !strings.HasSuffix(pem, pemCertFooter) {
		return Certificate{}, errors.New("invalid certificate footer")
	}
	// Trim instead of slicing by offsets: on malformed input the header and
	// footer can overlap, and offset slicing would then panic.
	body := strings.TrimSuffix(strings.TrimPrefix(pem, pemCertHeader), pemCertFooter)
	body = strings.TrimSpace(body)
	if body == "" {
		return Certificate{}, errors.New("empty certificate body")
	}
	// The base64 decoder skips embedded '\r' and '\n', so the line breaks
	// inside the body need no further handling.
	der, err := base64.StdEncoding.DecodeString(body)
	if err != nil {
		return Certificate{}, err
	}
	return LoadCertFromDER(der)
}

// ToPEM encodes the certificate as PEM: the header line, the base64 body
// wrapped at 64 characters per line, and the footer line with no trailing
// newline.
func (cert Certificate) ToPEM() string {
	const lineLen = 64
	encoded := base64.StdEncoding.EncodeToString(cert.Store.Raw)

	var b strings.Builder
	b.Grow(len(pemCertHeader) + 1 + len(encoded) + len(encoded)/lineLen + 1 + len(pemCertFooter))
	b.WriteString(pemCertHeader)
	b.WriteByte('\n')
	for start := 0; start < len(encoded); start += lineLen {
		end := start + lineLen
		if end > len(encoded) {
			end = len(encoded)
		}
		b.WriteString(encoded[start:end])
		b.WriteByte('\n')
	}
	b.WriteString(pemCertFooter)
	return b.String()
}

// IsSelfSigned reports whether the certificate's signature verifies under
// its own public key.
//
// CheckSignature is used rather than CheckSignatureFrom(self): the latter
// additionally requires the "parent" to be a CA (so a self-signed end-entity
// certificate would not be detected) and is documented as rejecting SHA-1
// signatures, which some long-lived roots still carry on their
// self-signature.
func (cert Certificate) IsSelfSigned() bool {
	s := cert.Store
	return s.CheckSignature(s.SignatureAlgorithm, s.RawTBSCertificate, s.Signature) == nil
}

// IsRoot reports whether the certificate is a self-signed CA certificate.
func (cert Certificate) IsRoot() bool {
	return cert.IsSelfSigned() && cert.Store.IsCA
}

// IsIntermediate reports whether the certificate is a CA certificate that
// is not self-signed, i.e. one that must itself be issued by another CA.
func (cert Certificate) IsIntermediate() bool {
	return cert.Store.IsCA && !cert.IsSelfSigned()
}

// IsLeaf reports whether the certificate looks like an end-entity TLS
// certificate: it is not a CA, and every extended key usage it lists (if
// any) is Any, ServerAuth or ClientAuth.
func (cert Certificate) IsLeaf() bool {
	if cert.Store.IsCA {
		return false
	}
	for _, use := range cert.Store.ExtKeyUsage {
		switch use {
		case x509.ExtKeyUsageAny, x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth:
			// allowed
		default:
			return false
		}
	}
	return true
}

// GetHTTPAuthorityURL returns the first http:// or https:// URL listed in
// the certificate's Authority Information Access "CA Issuers" entries
// (x509.Certificate.IssuingCertificateURL), and whether one was found.
func (cert Certificate) GetHTTPAuthorityURL() (string, bool) {
	for _, url := range cert.Store.IssuingCertificateURL {
		if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
			return url, true
		}
	}
	return "", false
}

// Parent reports whether parent signed cert. It uses CheckSignatureFrom, so
// parent must also satisfy the CA constraints that method enforces.
func (cert Certificate) Parent(parent Certificate) bool {
	return cert.Store.CheckSignatureFrom(parent.Store) == nil
}
// CertificateChain is an ordered list of certificates, leaf first and root
// last. Leaf and Root point into the Certificates backing array (first and
// last element); after modifying Certificates directly they may refer to
// stale elements, so prefer Prepend/Append, which re-point them.
type CertificateChain struct {
	Certificates []Certificate
	Leaf, Root   *Certificate
}

// pemCertificatePattern matches one PEM certificate block, tolerating CRLF
// line endings. It is compiled once rather than on every LoadCertsFromPEM call.
var pemCertificatePattern = regexp.MustCompile(
	"-----BEGIN CERTIFICATE-----\r?\n([A-Za-z0-9/+=\r\n]+)\r?\n-----END CERTIFICATE-----")

// LoadCertsFromPEM parses every PEM certificate block in pem, in order of
// appearance; text outside the blocks is ignored. The first certificate
// becomes Leaf and the last becomes Root. It returns an error if any block
// fails to parse or if no block is found; on error the returned chain is
// the zero value.
func LoadCertsFromPEM(pem string) (CertificateChain, error) {
	blocks := pemCertificatePattern.FindAllString(pem, -1)
	certs := make([]Certificate, 0, len(blocks))
	for _, block := range blocks {
		cert, err := LoadCertFromPEM(block)
		if err != nil {
			return CertificateChain{}, err
		}
		certs = append(certs, cert)
	}
	if len(certs) == 0 {
		return CertificateChain{}, errors.New("no certificates found")
	}
	return CertificateChain{
		Certificates: certs,
		Leaf:         &certs[0],
		Root:         &certs[len(certs)-1],
	}, nil
}
// BAKED_IN_ROOTS is the text of roots.pem, embedded into the binary at build
// time. LoadSystemStore parses it with LoadCertsFromPEM as a fallback root
// bundle. NOTE(review): presumably roots.pem is a PEM bundle of trusted root
// CAs — confirm against the file shipped alongside this source.
//
//go:embed roots.pem
var BAKED_IN_ROOTS string
// LoadSystemStore loads the pool of trusted root certificates used to finish
// chains.
//
// If SSL_CERT_FILE is set, that file is read and any read error is returned;
// an empty SSL_CERT_FILE selects the embedded BAKED_IN_ROOTS bundle.
// Otherwise the first existing bundle among the well-known locations for the
// current OS is used. If none of those files exists (or the OS has no known
// location), the embedded bundle is used instead of failing. Errors other
// than "file does not exist" are returned as-is.
func LoadSystemStore() (CertificateChain, error) {
	if path, ok := os.LookupEnv("SSL_CERT_FILE"); ok {
		if path == "" {
			return LoadCertsFromPEM(BAKED_IN_ROOTS)
		}
		pem, err := os.ReadFile(path)
		if err != nil {
			return CertificateChain{}, err
		}
		return LoadCertsFromPEM(string(pem))
	}

	var candidates []string
	switch runtime.GOOS {
	case "linux":
		// Same search order as Go's crypto/x509 root_linux.go.
		candidates = []string{
			"/etc/ssl/certs/ca-certificates.crt",                // Debian/Ubuntu/Gentoo etc.
			"/etc/pki/tls/certs/ca-bundle.crt",                  // Fedora/RHEL 6
			"/etc/ssl/ca-bundle.pem",                            // OpenSUSE
			"/etc/pki/tls/cacert.pem",                           // OpenELEC
			"/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7
			"/etc/ssl/cert.pem",                                 // Alpine Linux
		}
	case "darwin":
		candidates = []string{"/etc/ssl/cert.pem"}
	}

	for _, path := range candidates {
		pem, err := os.ReadFile(path)
		if errors.Is(err, os.ErrNotExist) {
			continue
		}
		if err != nil {
			return CertificateChain{}, err
		}
		return LoadCertsFromPEM(string(pem))
	}
	return LoadCertsFromPEM(BAKED_IN_ROOTS)
}
// ToPEM concatenates the PEM encoding of every certificate in chain order,
// each followed by a newline. A strings.Builder avoids the quadratic copying
// of repeated string concatenation.
func (chain CertificateChain) ToPEM() string {
	var b strings.Builder
	for _, link := range chain.Certificates {
		b.WriteString(link.ToPEM())
		b.WriteByte('\n')
	}
	return b.String()
}
// Prepend inserts cert at the front of the chain, making it the new Leaf.
// Building the new slice always allocates a fresh backing array, so Root is
// re-pointed as well; this also sets Root when the chain was empty.
func (chain *CertificateChain) Prepend(cert Certificate) {
	chain.Certificates = append([]Certificate{cert}, chain.Certificates...)
	chain.Leaf = &chain.Certificates[0]
	chain.Root = &chain.Certificates[len(chain.Certificates)-1]
}
// Append adds cert to the end of the chain, making it the new Root.
// append may reallocate the backing array, so Leaf is re-pointed as well;
// this also sets Leaf when the chain was empty.
func (chain *CertificateChain) Append(cert Certificate) {
	chain.Certificates = append(chain.Certificates, cert)
	chain.Leaf = &chain.Certificates[0]
	chain.Root = &chain.Certificates[len(chain.Certificates)-1]
}
// IsLinked reports whether every certificate is signed by the one after it,
// i.e. Certificates[i] verifies (via CheckSignatureFrom) against
// Certificates[i+1]. Chains with fewer than two certificates, including an
// empty chain, are trivially linked.
func (chain CertificateChain) IsLinked() bool {
	certs := chain.Certificates
	for i := 1; i < len(certs); i++ {
		if certs[i-1].Store.CheckSignatureFrom(certs[i].Store) != nil {
			return false
		}
	}
	return true
}
// IsComplete reports whether the chain ends in a root (self-signed CA) and
// starts with a leaf. If the first optional argument is true, it also
// requires the chain to be linked (see IsLinked); any further arguments are
// ignored. A chain without Leaf/Root pointers (e.g. the zero value) is never
// complete.
func (chain CertificateChain) IsComplete(test_links ...bool) bool {
	if chain.Leaf == nil || chain.Root == nil {
		return false
	}
	if !chain.Root.IsRoot() || !chain.Leaf.IsLeaf() {
		return false
	}
	checkLinks := len(test_links) > 0 && test_links[0]
	return !checkLinks || chain.IsLinked()
}
// Links returns the number of certificates currently held in the chain.
func (chain CertificateChain) Links() int {
	count := len(chain.Certificates)
	return count
}
func main() {
has_sys_pool := true
sys_pool, err := LoadSystemStore()
if err != nil {
fmt.Fprintf(os.Stderr, "error: unable to load system pool: %s\n", err)
has_sys_pool = false
}
flag.Parse()
args := flag.Args()
for _, arg := range args {
pem, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
continue
}
chain, err := LoadCertsFromPEM(string(pem))
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
continue
} else if chain.Leaf.IsSelfSigned() {
fmt.Fprintf(os.Stderr, "%s: certificate is self-signed\n", arg)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
continue
} else if ! chain.IsLinked() {
fmt.Fprintf(os.Stderr, "%s: error: certificates do not form a chain\n", arg)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
continue
}
for ! chain.Root.IsRoot() {
found_link := false
aia_url, found_url := chain.Root.GetHTTPAuthorityURL()
if found_url {
res, err := http.Get(aia_url)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
if ! has_sys_pool {
fmt.Fprintf(os.Stderr, "%s: attempting alternate chain discovery methods\n", arg)
}
} else {
cert_data, err := io.ReadAll(res.Body)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
if ! has_sys_pool {
fmt.Fprintf(os.Stderr, "%s: attempting alternate chain discovery methods\n", arg)
}
} else {
cert := Certificate { Store: nil }
mime := "der"
switch res.Header.Get("content-type") {
case "application/pkix-cer": mime = "der"
case "application/pkix-cert": mime = "der"
case "application/x-x509-ca-cert": mime = "der"
case "application/x-pem-file": mime = "pem"
}
if strings.HasSuffix(aia_url, ".der") || mime == "der" {
link, err := LoadCertFromDER(cert_data)
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
if ! has_sys_pool {
fmt.Fprintf(os.Stderr, "%s: attempting alternate chain discovery methods\n", arg)
}
} else {
cert = link
found_link = true
}
} else if strings.HasSuffix(aia_url, ".pem") || mime == "pem" {
link, err := LoadCertFromPEM(string(cert_data))
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
if ! has_sys_pool {
fmt.Fprintf(os.Stderr, "%s: attempting alternate chain discovery methods\n", arg)
}
} else {
cert = link
found_link = true
}
}
if found_link {
chain.Append(cert)
} else {
fmt.Fprintf(os.Stderr, "%s: unable to determine link file type\n", arg)
if ! has_sys_pool {
fmt.Fprintf(os.Stderr, "%s: attempting alternate chain discovery methods\n", arg)
}
}
}
res.Body.Close()
}
}
if ! found_link && has_sys_pool {
for _, sys_cert := range sys_pool.Certificates {
if chain.Root.Parent(sys_cert) {
chain.Append(sys_cert)
found_link = true
break
}
}
}
if ! found_link {
break
}
}
if chain.IsComplete(true) {
file, err := ioutil.TempFile("", "*-chain.pem")
if err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
} else {
if _, err := file.Write([]byte(chain.ToPEM())); err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %s\n", arg, err)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
} else {
fmt.Printf("%s: complete chain has been saved to %s\n", arg, file.Name())
}
file.Close()
}
} else {
fmt.Fprintf(os.Stderr, "%s: unable to determine next link in chain\n", arg)
fmt.Fprintf(os.Stderr, "%s: skipping file...\n", arg)
}
}
}