Implement Get and wildcard match

Lots of changes and new API's here.

Fixes #7.
This commit is contained in:
Kevin Burke
2017-04-23 11:42:22 -07:00
parent c40e54d2bf
commit 67c39ca6b4
12 changed files with 502 additions and 26 deletions

View File

@@ -1,3 +1,7 @@
lint:
go vet ./...
staticcheck ./...
test:
go test -timeout=10ms ./...
test: lint
@# the timeout helps guard against infinite recursion
go test -timeout=30ms ./...

View File

@@ -1,14 +1,14 @@
# ssh_config
This is a Go parser for `ssh_config` files. Importantly, this parser attempts to
preserve comments, so you can manipulate a `ssh_config` file from a program, if
your heart wishes.
This is a Go parser for `ssh_config` files. Importantly, this parser attempts
to preserve comments in a given file, so you can manipulate a `ssh_config` file
from a program, if your heart desires.
Example usage:
```go
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
cfg, _ := ssh_config.LoadReader(f)
cfg, _ := ssh_config.Decode(f)
for _, host := range cfg.Hosts {
fmt.Println("patterns:", host.Patterns)
for _, node := range host.Nodes {
@@ -20,6 +20,15 @@ for _, host := range cfg.Hosts {
fmt.Println(cfg.String())
```
This is very alpha software. In particular the most useful thing you want to do
with this is figure out the correct HostName and User for a particular host
pattern. There's no way to do this, currently.
The `ssh_config` program will attempt to read values from `$HOME/.ssh/config`,
falling back to `/etc/ssh/ssh_config`.
```go
port := ssh_config.Get("myhost", "Port")
```
## Donating
Donations free up time to make improvements to the library, and respond to
bug reports. You can send donations via Paypal's "Send Money" feature to
kev@inburke.com. Donations are not tax deductible in the USA.

237
config.go
View File

@@ -6,14 +6,116 @@ import (
"fmt"
"io"
"os"
osuser "os/user"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
)
type configFinder func() string
type UserSettings struct {
systemConfig *Config
systemConfigFinder configFinder
userConfig *Config
userConfigFinder configFinder
username string
loadConfigs sync.Once
onceErr error
IgnoreErrors bool
}
func userConfigFinder() string {
user, err := osuser.Current()
var home string
if err == nil {
home = user.HomeDir
} else {
home = os.Getenv("HOME")
}
return filepath.Join(home, ".ssh", "config")
}
var DefaultUserSettings = &UserSettings{
IgnoreErrors: false,
systemConfigFinder: systemConfigFinder,
userConfigFinder: userConfigFinder,
}
func systemConfigFinder() string {
return filepath.Join("/", "etc", "ssh", "ssh_config")
}
func findVal(c *Config, alias, key string) (string, error) {
if c == nil {
return "", nil
}
return c.Get(alias, key)
}
func Get(alias, key string) string {
return DefaultUserSettings.Get(alias, key)
}
func GetStrict(alias, key string) (string, error) {
return DefaultUserSettings.GetStrict(alias, key)
}
// Get finds the first value for key within a declaration that matches the
// alias. Get returns the empty string if no value was found, or if IgnoreErrors
// is false and we could not parse the configuration file. Use GetStrict to
// disambiguate the latter cases.
//
// The match for key is case sensitive.
func (u *UserSettings) Get(alias, key string) string {
val, err := u.GetStrict(alias, key)
if err != nil {
return ""
}
return val
}
// Get finds the first value for key within a declaration that matches the
// alias. For more on the pattern syntax, see the manpage for ssh_config.
//
// error will be non-nil if and only if the user's configuration file or the
// system configuration file could not be parsed, and u.IgnoreErrors is true.
func (u *UserSettings) GetStrict(alias, key string) (string, error) {
u.loadConfigs.Do(func() {
// can't parse user file, that's ok.
var filename string
if u.userConfigFinder == nil {
filename = userConfigFinder()
} else {
filename = u.userConfigFinder()
}
var err error
u.userConfig, err = parseFile(filename)
if err != nil && os.IsNotExist(err) == false {
u.onceErr = err
return
}
if u.systemConfigFinder == nil {
filename = systemConfigFinder()
} else {
filename = u.systemConfigFinder()
}
u.systemConfig, err = parseFile(filename)
if err != nil && os.IsNotExist(err) == false {
u.onceErr = err
return
}
})
if u.onceErr != nil && u.IgnoreErrors == false {
return "", u.onceErr
}
val, err := findVal(u.userConfig, alias, key)
if err != nil || val != "" {
return val, err
}
return findVal(u.systemConfig, alias, key)
}
func parseFile(filename string) (*Config, error) {
@@ -45,6 +147,38 @@ func Decode(r io.Reader) (c *Config, err error) {
type Config struct {
position Position
Hosts []*Host
// Set to true to silently ignore errors parsing the config file.
IgnoreErrors bool
}
func (c *Config) Get(alias, key string) (string, error) {
lowerKey := strings.ToLower(key)
for _, host := range c.Hosts {
if !host.Matches(alias) {
continue
}
for _, node := range host.Nodes {
switch t := node.(type) {
case *Empty:
continue
case *KV:
// "keys are case insensitive" per the spec
lkey := strings.ToLower(t.Key)
if lkey == "include" {
panic("can't handle Include directives")
}
if lkey == "match" {
panic("can't handle Match directives")
}
if lkey == lowerKey {
return t.Value, nil
}
default:
return "", fmt.Errorf("unknown Node type %v", t)
}
}
}
return "", nil
}
func (c *Config) String() string {
@@ -55,24 +189,101 @@ func (c *Config) String() string {
return buf.String()
}
type Pattern struct {
str string
regex *regexp.Regexp
}
func (p Pattern) String() string {
return p.str
}
// Copied from regexp.go with * and ? removed.
var specialBytes = []byte(`\.+()|[]{}^$`)
func special(b byte) bool {
return bytes.IndexByte(specialBytes, b) >= 0
}
// NewPattern creates a new Pattern for matching hosts.
func NewPattern(s string) (*Pattern, error) {
// From the manpage:
// A pattern consists of zero or more non-whitespace characters,
// `*' (a wildcard that matches zero or more characters),
// or `?' (a wildcard that matches exactly one character).
// For example, to specify a set of declarations for any host in the
// ".co.uk" set of domains, the following pattern could be used:
//
// Host *.co.uk
//
// The following pattern would match any host in the 192.168.0.[0-9] network range:
//
// Host 192.168.0.?
var buf bytes.Buffer
buf.WriteByte('^')
for i := 0; i < len(s); i++ {
// A byte loop is correct because all metacharacters are ASCII.
switch b := s[i]; b {
case '*':
buf.WriteString(".*")
case '?':
buf.WriteString(".?")
default:
// borrowing from QuoteMeta here.
if special(b) {
buf.WriteByte('\\')
}
buf.WriteByte(b)
}
}
buf.WriteByte('$')
r, err := regexp.Compile(buf.String())
if err != nil {
return nil, err
}
return &Pattern{str: s, regex: r}, nil
}
type Host struct {
// A list of host patterns that should match this host.
Patterns []string
Patterns []*Pattern
// A Node is either a key/value pair or a comment line.
Nodes []Node
// EOLComment is the comment (if any) terminating the Host line.
EOLComment string
hasEquals bool
leadingSpace uint16 // TODO: handle spaces vs tabs here.
// The file starts with an implicit "Host *" declaration.
implicit bool
}
func (h *Host) Matches(alias string) bool {
found := false
for i := range h.Patterns {
if h.Patterns[i].regex.MatchString(alias) {
found = true
break
}
}
return found
}
func (h *Host) String() string {
var buf bytes.Buffer
if h.implicit == false {
buf.WriteString(strings.Repeat(" ", int(h.leadingSpace)))
buf.WriteString("Host ")
buf.WriteString(strings.Join(h.Patterns, " "))
buf.WriteString("Host")
if h.hasEquals {
buf.WriteString(" = ")
} else {
buf.WriteString(" ")
}
for i, pat := range h.Patterns {
buf.WriteString(pat.str)
if i < len(h.Patterns)-1 {
buf.WriteString(" ")
}
}
if h.EOLComment != "" {
buf.WriteString(" #")
buf.WriteString(h.EOLComment)
@@ -80,7 +291,6 @@ func (h *Host) String() string {
buf.WriteByte('\n')
}
for i := range h.Nodes {
//fmt.Printf("%q\n", h.Nodes[i].String())
buf.WriteString(h.Nodes[i].String())
buf.WriteByte('\n')
}
@@ -96,6 +306,7 @@ type KV struct {
Key string
Value string
Comment string
hasEquals bool
leadingSpace uint16 // Space before the key. TODO handle spaces vs tabs.
position Position
}
@@ -108,7 +319,11 @@ func (k *KV) String() string {
if k == nil {
return ""
}
line := fmt.Sprintf("%s%s %s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, k.Value)
equals := " "
if k.hasEquals {
equals = " = "
}
line := fmt.Sprintf("%s%s%s%s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, equals, k.Value)
if k.Comment != "" {
line += " #" + k.Comment
}
@@ -135,10 +350,20 @@ func (e *Empty) String() string {
return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
}
var matchAll *Pattern
func init() {
var err error
matchAll, err = NewPattern("*")
if err != nil {
panic(err)
}
}
func newConfig() *Config {
return &Config{
Hosts: []*Host{
&Host{implicit: true, Patterns: []string{"*"}, Nodes: make([]Node, 0)},
&Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)},
},
}
}

View File

@@ -29,3 +29,154 @@ func TestDecode(t *testing.T) {
}
}
}
func testConfigFinder(filename string) func() string {
return func() string { return filename }
}
func nullConfigFinder() string {
return ""
}
func TestGet(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config1"),
}
val := us.Get("wap", "User")
if val != "root" {
t.Errorf("expected to find User root, got %q", val)
}
}
func TestGetWildcard(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config3"),
}
val := us.Get("bastion.stage.i.us.example.net", "Port")
if val != "22" {
t.Errorf("expected to find Port 22, got %q", val)
}
val = us.Get("bastion.net", "Port")
if val != "25" {
t.Errorf("expected to find Port 24, got %q", val)
}
val = us.Get("10.2.3.4", "Port")
if val != "23" {
t.Errorf("expected to find Port 23, got %q", val)
}
val = us.Get("101.2.3.4", "Port")
if val != "25" {
t.Errorf("expected to find Port 24, got %q", val)
}
val = us.Get("20.20.20.4", "Port")
if val != "24" {
t.Errorf("expected to find Port 24, got %q", val)
}
val = us.Get("20.20.20.20", "Port")
if val != "25" {
t.Errorf("expected to find Port 25, got %q", val)
}
}
func TestGetExtraSpaces(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/extraspace"),
}
val := us.Get("test.test", "Port")
if val != "1234" {
t.Errorf("expected to find Port 1234, got %q", val)
}
}
func TestGetCaseInsensitive(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config1"),
}
val := us.Get("wap", "uSER")
if val != "root" {
t.Errorf("expected to find User root, got %q", val)
}
}
func TestGetEmpty(t *testing.T) {
us := &UserSettings{
userConfigFinder: nullConfigFinder,
systemConfigFinder: nullConfigFinder,
}
val, err := us.GetStrict("wap", "User")
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
if val != "" {
t.Errorf("expected to get empty string, got %q", val)
}
}
func TestGetEqsign(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/eqsign"),
}
val := us.Get("test.test", "Port")
if val != "1234" {
t.Errorf("expected to find Port 1234, got %q", val)
}
val = us.Get("test.test", "Port2")
if val != "5678" {
t.Errorf("expected to find Port2 5678, got %q", val)
}
}
func TestInclude(t *testing.T) {
t.Skip("can't currently handle Include directives")
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/include"),
}
val := us.Get("test.test", "Compression")
if val != "yes" {
t.Errorf("expected to find Compression=yes in included file, got %q", val)
}
}
var matchTests = []struct {
in []string
alias string
want bool
}{
{[]string{"*"}, "any.test", true},
{[]string{"a", "b", "*", "c"}, "any.test", true},
{[]string{"a", "b", "c"}, "any.test", false},
{[]string{"any.test"}, "any1test", false},
{[]string{"192.168.0.?"}, "192.168.0.1", true},
{[]string{"192.168.0.?"}, "192.168.0.10", false},
{[]string{"*.co.uk"}, "bbc.co.uk", true},
{[]string{"*.co.uk"}, "subdomain.bbc.co.uk", true},
{[]string{"*.*.co.uk"}, "bbc.co.uk", false},
{[]string{"*.*.co.uk"}, "subdomain.bbc.co.uk", true},
}
func TestMatches(t *testing.T) {
for _, tt := range matchTests {
patterns := make([]*Pattern, len(tt.in), len(tt.in))
for i := range tt.in {
pat, err := NewPattern(tt.in[i])
if err != nil {
t.Fatalf("error compiling pattern %s: %v", tt.in[i], err)
}
patterns[i] = pat
}
host := &Host{
Patterns: patterns,
}
got := host.Matches(tt.alias)
if got != tt.want {
t.Errorf("host(%q).Matches(%q): got %v, want %v", tt.in, tt.alias, got, tt.want)
}
}
}

View File

@@ -36,21 +36,50 @@ func (s *sshLexer) lexComment(previousState sshLexStateFn) sshLexStateFn {
}
}
// lex the space after an equals sign in a function
func (s *sshLexer) lexRspace() sshLexStateFn {
for {
next := s.peek()
if !isSpace(next) {
break
}
s.skip()
}
return s.lexRvalue
}
func (s *sshLexer) lexEquals() sshLexStateFn {
for {
next := s.peek()
if next == '=' {
s.emit(tokenEquals)
s.skip()
return s.lexRspace
}
// TODO error handling here; newline eof etc.
if !isSpace(next) {
break
}
s.skip()
}
return s.lexRvalue
}
func (s *sshLexer) lexKey() sshLexStateFn {
growingString := ""
for r := s.peek(); isKeyChar(r); r = s.peek() {
// simplified a lot here
if isSpace(r) {
if isSpace(r) || r == '=' {
s.emitWithValue(tokenKey, growingString)
s.skip()
return s.lexRvalue
return s.lexEquals
}
growingString += string(r)
s.next()
}
s.emitWithValue(tokenKey, growingString)
return s.lexVoid
return s.lexEquals
}
func (s *sshLexer) lexRvalue() sshLexStateFn {

View File

@@ -75,8 +75,12 @@ func (p *sshParser) parseStart() sshParserStateFn {
func (p *sshParser) parseKV() sshParserStateFn {
key := p.getToken()
p.assume(tokenString)
hasEquals := false
val := p.getToken()
if val.typ == tokenEquals {
hasEquals = true
val = p.getToken()
}
comment := ""
tok := p.peek()
if tok.typ == tokenComment && tok.Position.Line == val.Position.Line {
@@ -84,16 +88,26 @@ func (p *sshParser) parseKV() sshParserStateFn {
comment = tok.val
}
if key.val == "Host" {
patterns := strings.Split(val.val, " ")
for i := range patterns {
if patterns[i] == "" {
patterns = append(patterns[:i], patterns[i+1:]...)
strPatterns := strings.Split(val.val, " ")
for i := range strPatterns {
if strPatterns[i] == "" {
strPatterns = append(strPatterns[:i], strPatterns[i+1:]...)
}
}
patterns := make([]*Pattern, len(strPatterns))
for i := range strPatterns {
pat, err := NewPattern(strPatterns[i])
if err != nil {
p.raiseError(val, "Invalid host pattern: %v", err)
return nil
}
patterns[i] = pat
}
p.config.Hosts = append(p.config.Hosts, &Host{
Patterns: patterns,
Nodes: make([]Node, 0),
EOLComment: comment,
hasEquals: hasEquals,
})
return p.parseStart
}
@@ -102,6 +116,7 @@ func (p *sshParser) parseKV() sshParserStateFn {
Key: key.val,
Value: val.val,
Comment: comment,
hasEquals: hasEquals,
leadingSpace: uint16(key.Position.Col) - 1,
position: key.Position,
}

3
testdata/anotherfile vendored Normal file
View File

@@ -0,0 +1,3 @@
# Not sure that this actually works; Include might need to be relative to the
# load directory.
Compression yes

31
testdata/config3 vendored Normal file
View File

@@ -0,0 +1,31 @@
Host bastion.*.i.*.example.net
User simon.thulbourn
Port 22
ForwardAgent yes
IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa
UseKeychain yes
Host 10.*
User simon.thulbourn
Port 23
ForwardAgent yes
StrictHostKeyChecking no
UserKnownHostsFile /dev/null
IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa
UseKeychain yes
ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p
Host 20.20.20.?
User simon.thulbourn
Port 24
ForwardAgent yes
StrictHostKeyChecking no
UserKnownHostsFile /dev/null
IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa
UseKeychain yes
ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p
Host *
IdentityFile /Users/%u/.ssh/%h/%r/id_rsa
UseKeychain yes
Port 25

4
testdata/eqsign vendored Normal file
View File

@@ -0,0 +1,4 @@
Host=test.test
Port =1234
Port2= 5678
Compression yes

2
testdata/extraspace vendored Normal file
View File

@@ -0,0 +1,2 @@
Host test.test
Port 1234

2
testdata/include vendored Normal file
View File

@@ -0,0 +1,2 @@
Host test.test
Include anotherfile

View File

@@ -28,6 +28,7 @@ const (
tokenEmptyLine
tokenComment
tokenKey
tokenEquals
tokenString
)