From 67c39ca6b47d468e90f304db1a51c130d349269d Mon Sep 17 00:00:00 2001 From: Kevin Burke Date: Sun, 23 Apr 2017 11:42:22 -0700 Subject: [PATCH] Implement Get and wildcard match Lots of changes and new API's here. Fixes #7. --- Makefile | 8 +- README.md | 23 ++-- config.go | 243 +++++++++++++++++++++++++++++++++++++++++-- config_test.go | 151 +++++++++++++++++++++++++++ lexer.go | 35 ++++++- parser.go | 25 ++++- testdata/anotherfile | 3 + testdata/config3 | 31 ++++++ testdata/eqsign | 4 + testdata/extraspace | 2 + testdata/include | 2 + token.go | 1 + 12 files changed, 502 insertions(+), 26 deletions(-) create mode 100644 testdata/anotherfile create mode 100644 testdata/config3 create mode 100644 testdata/eqsign create mode 100644 testdata/extraspace create mode 100644 testdata/include diff --git a/Makefile b/Makefile index 9db76f3..0bd919b 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ +lint: + go vet ./... + staticcheck ./... -test: - go test -timeout=10ms ./... +test: lint + @# the timeout helps guard against infinite recursion + go test -timeout=30ms ./... diff --git a/README.md b/README.md index 2cade6a..dd8ce78 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ # ssh_config -This is a Go parser for `ssh_config` files. Importantly, this parser attempts to -preserve comments, so you can manipulate a `ssh_config` file from a program, if -your heart wishes. +This is a Go parser for `ssh_config` files. Importantly, this parser attempts +to preserve comments in a given file, so you can manipulate a `ssh_config` file +from a program, if your heart desires. Example usage: ```go f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) -cfg, _ := ssh_config.LoadReader(f) +cfg, _ := ssh_config.Decode(f) for _, host := range cfg.Hosts { fmt.Println("patterns:", host.Patterns) for _, node := range host.Nodes { @@ -20,6 +20,15 @@ for _, host := range cfg.Hosts { fmt.Println(cfg.String()) ``` -This is very alpha software. In particular the most useful thing you want to do -with this is figure out the correct HostName and User for a particular host -pattern. There's no way to do this, currently. +The `ssh_config` program will attempt to read values from `$HOME/.ssh/config`, +falling back to `/etc/ssh/ssh_config`. + +```go +port := ssh_config.Get("myhost", "Port") +``` + +## Donating + +Donations free up time to make improvements to the library, and respond to +bug reports. You can send donations via Paypal's "Send Money" feature to +kev@inburke.com. Donations are not tax deductible in the USA. diff --git a/config.go b/config.go index 251065a..fdd4385 100644 --- a/config.go +++ b/config.go @@ -6,14 +6,116 @@ import ( "fmt" "io" "os" + osuser "os/user" + "path/filepath" + "regexp" "runtime" "strings" + "sync" ) +type configFinder func() string + type UserSettings struct { - systemConfig *Config - userConfig *Config - username string + systemConfig *Config + systemConfigFinder configFinder + userConfig *Config + userConfigFinder configFinder + username string + loadConfigs sync.Once + onceErr error + IgnoreErrors bool +} + +func userConfigFinder() string { + user, err := osuser.Current() + var home string + if err == nil { + home = user.HomeDir + } else { + home = os.Getenv("HOME") + } + return filepath.Join(home, ".ssh", "config") +} + +var DefaultUserSettings = &UserSettings{ + IgnoreErrors: false, + systemConfigFinder: systemConfigFinder, + userConfigFinder: userConfigFinder, +} + +func systemConfigFinder() string { + return filepath.Join("/", "etc", "ssh", "ssh_config") +} + +func findVal(c *Config, alias, key string) (string, error) { + if c == nil { + return "", nil + } + return c.Get(alias, key) +} + +func Get(alias, key string) string { + return DefaultUserSettings.Get(alias, key) +} + +func GetStrict(alias, key string) (string, error) { + return DefaultUserSettings.GetStrict(alias, key) +} + +// Get finds the first value for key within a declaration that matches the +// alias. Get returns the empty string if no value was found, or if IgnoreErrors +// is false and we could not parse the configuration file. Use GetStrict to +// disambiguate the latter cases. +// +// The match for key is case sensitive. +func (u *UserSettings) Get(alias, key string) string { + val, err := u.GetStrict(alias, key) + if err != nil { + return "" + } + return val +} + +// Get finds the first value for key within a declaration that matches the +// alias. For more on the pattern syntax, see the manpage for ssh_config. +// +// error will be non-nil if and only if the user's configuration file or the +// system configuration file could not be parsed, and u.IgnoreErrors is true. +func (u *UserSettings) GetStrict(alias, key string) (string, error) { + u.loadConfigs.Do(func() { + // can't parse user file, that's ok. + var filename string + if u.userConfigFinder == nil { + filename = userConfigFinder() + } else { + filename = u.userConfigFinder() + } + var err error + u.userConfig, err = parseFile(filename) + if err != nil && os.IsNotExist(err) == false { + u.onceErr = err + return + } + if u.systemConfigFinder == nil { + filename = systemConfigFinder() + } else { + filename = u.systemConfigFinder() + } + u.systemConfig, err = parseFile(filename) + if err != nil && os.IsNotExist(err) == false { + u.onceErr = err + return + } + }) + if u.onceErr != nil && u.IgnoreErrors == false { + return "", u.onceErr + } + val, err := findVal(u.userConfig, alias, key) + if err != nil || val != "" { + return val, err + } + return findVal(u.systemConfig, alias, key) } func parseFile(filename string) (*Config, error) { @@ -45,6 +147,38 @@ func Decode(r io.Reader) (c *Config, err error) { type Config struct { position Position Hosts []*Host + // Set to true to silently ignore errors parsing the config file. + IgnoreErrors bool +} + +func (c *Config) Get(alias, key string) (string, error) { + lowerKey := strings.ToLower(key) + for _, host := range c.Hosts { + if !host.Matches(alias) { + continue + } + for _, node := range host.Nodes { + switch t := node.(type) { + case *Empty: + continue + case *KV: + // "keys are case insensitive" per the spec + lkey := strings.ToLower(t.Key) + if lkey == "include" { + panic("can't handle Include directives") + } + if lkey == "match" { + panic("can't handle Match directives") + } + if lkey == lowerKey { + return t.Value, nil + } + default: + return "", fmt.Errorf("unknown Node type %v", t) + } + } + } + return "", nil } func (c *Config) String() string { @@ -55,24 +189,101 @@ func (c *Config) String() string { return buf.String() } +type Pattern struct { + str string + regex *regexp.Regexp +} + +func (p Pattern) String() string { + return p.str +} + +// Copied from regexp.go with * and ? removed. +var specialBytes = []byte(`\.+()|[]{}^$`) + +func special(b byte) bool { + return bytes.IndexByte(specialBytes, b) >= 0 +} + +// NewPattern creates a new Pattern for matching hosts. +func NewPattern(s string) (*Pattern, error) { + // From the manpage: + // A pattern consists of zero or more non-whitespace characters, + // `*' (a wildcard that matches zero or more characters), + // or `?' (a wildcard that matches exactly one character). + // For example, to specify a set of declarations for any host in the + // ".co.uk" set of domains, the following pattern could be used: + // + // Host *.co.uk + // + // The following pattern would match any host in the 192.168.0.[0-9] network range: + // + // Host 192.168.0.? + var buf bytes.Buffer + buf.WriteByte('^') + for i := 0; i < len(s); i++ { + // A byte loop is correct because all metacharacters are ASCII. + switch b := s[i]; b { + case '*': + buf.WriteString(".*") + case '?': + buf.WriteString(".?") + default: + // borrowing from QuoteMeta here. + if special(b) { + buf.WriteByte('\\') + } + buf.WriteByte(b) + } + } + buf.WriteByte('$') + r, err := regexp.Compile(buf.String()) + if err != nil { + return nil, err + } + return &Pattern{str: s, regex: r}, nil +} + type Host struct { // A list of host patterns that should match this host. - Patterns []string + Patterns []*Pattern // A Node is either a key/value pair or a comment line. Nodes []Node // EOLComment is the comment (if any) terminating the Host line. EOLComment string + hasEquals bool leadingSpace uint16 // TODO: handle spaces vs tabs here. // The file starts with an implicit "Host *" declaration. implicit bool } +func (h *Host) Matches(alias string) bool { + found := false + for i := range h.Patterns { + if h.Patterns[i].regex.MatchString(alias) { + found = true + break + } + } + return found +} + func (h *Host) String() string { var buf bytes.Buffer if h.implicit == false { buf.WriteString(strings.Repeat(" ", int(h.leadingSpace))) - buf.WriteString("Host ") - buf.WriteString(strings.Join(h.Patterns, " ")) + buf.WriteString("Host") + if h.hasEquals { + buf.WriteString(" = ") + } else { + buf.WriteString(" ") + } + for i, pat := range h.Patterns { + buf.WriteString(pat.str) + if i < len(h.Patterns)-1 { + buf.WriteString(" ") + } + } if h.EOLComment != "" { buf.WriteString(" #") buf.WriteString(h.EOLComment) @@ -80,7 +291,6 @@ func (h *Host) String() string { buf.WriteByte('\n') } for i := range h.Nodes { - //fmt.Printf("%q\n", h.Nodes[i].String()) buf.WriteString(h.Nodes[i].String()) buf.WriteByte('\n') } @@ -96,6 +306,7 @@ type KV struct { Key string Value string Comment string + hasEquals bool leadingSpace uint16 // Space before the key. TODO handle spaces vs tabs. position Position } @@ -108,7 +319,11 @@ func (k *KV) String() string { if k == nil { return "" } - line := fmt.Sprintf("%s%s %s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, k.Value) + equals := " " + if k.hasEquals { + equals = " = " + } + line := fmt.Sprintf("%s%s%s%s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, equals, k.Value) if k.Comment != "" { line += " #" + k.Comment } @@ -135,10 +350,20 @@ func (e *Empty) String() string { return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment) } +var matchAll *Pattern + +func init() { + var err error + matchAll, err = NewPattern("*") + if err != nil { + panic(err) + } +} + func newConfig() *Config { return &Config{ Hosts: []*Host{ - &Host{implicit: true, Patterns: []string{"*"}, Nodes: make([]Node, 0)}, + &Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)}, }, } } diff --git a/config_test.go b/config_test.go index 8b97fdf..5213d4a 100644 --- a/config_test.go +++ b/config_test.go @@ -29,3 +29,154 @@ func TestDecode(t *testing.T) { } } } + +func testConfigFinder(filename string) func() string { + return func() string { return filename } +} + +func nullConfigFinder() string { + return "" +} + +func TestGet(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val := us.Get("wap", "User") + if val != "root" { + t.Errorf("expected to find User root, got %q", val) + } +} + +func TestGetWildcard(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config3"), + } + + val := us.Get("bastion.stage.i.us.example.net", "Port") + if val != "22" { + t.Errorf("expected to find Port 22, got %q", val) + } + + val = us.Get("bastion.net", "Port") + if val != "25" { + t.Errorf("expected to find Port 24, got %q", val) + } + + val = us.Get("10.2.3.4", "Port") + if val != "23" { + t.Errorf("expected to find Port 23, got %q", val) + } + val = us.Get("101.2.3.4", "Port") + if val != "25" { + t.Errorf("expected to find Port 24, got %q", val) + } + val = us.Get("20.20.20.4", "Port") + if val != "24" { + t.Errorf("expected to find Port 24, got %q", val) + } + val = us.Get("20.20.20.20", "Port") + if val != "25" { + t.Errorf("expected to find Port 25, got %q", val) + } +} + +func TestGetExtraSpaces(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/extraspace"), + } + + val := us.Get("test.test", "Port") + if val != "1234" { + t.Errorf("expected to find Port 1234, got %q", val) + } +} + +func TestGetCaseInsensitive(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val := us.Get("wap", "uSER") + if val != "root" { + t.Errorf("expected to find User root, got %q", val) + } +} + +func TestGetEmpty(t *testing.T) { + us := &UserSettings{ + userConfigFinder: nullConfigFinder, + systemConfigFinder: nullConfigFinder, + } + val, err := us.GetStrict("wap", "User") + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if val != "" { + t.Errorf("expected to get empty string, got %q", val) + } +} + +func TestGetEqsign(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/eqsign"), + } + + val := us.Get("test.test", "Port") + if val != "1234" { + t.Errorf("expected to find Port 1234, got %q", val) + } + val = us.Get("test.test", "Port2") + if val != "5678" { + t.Errorf("expected to find Port2 5678, got %q", val) + } +} + +func TestInclude(t *testing.T) { + t.Skip("can't currently handle Include directives") + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/include"), + } + val := us.Get("test.test", "Compression") + if val != "yes" { + t.Errorf("expected to find Compression=yes in included file, got %q", val) + } +} + +var matchTests = []struct { + in []string + alias string + want bool +}{ + {[]string{"*"}, "any.test", true}, + {[]string{"a", "b", "*", "c"}, "any.test", true}, + {[]string{"a", "b", "c"}, "any.test", false}, + {[]string{"any.test"}, "any1test", false}, + {[]string{"192.168.0.?"}, "192.168.0.1", true}, + {[]string{"192.168.0.?"}, "192.168.0.10", false}, + {[]string{"*.co.uk"}, "bbc.co.uk", true}, + {[]string{"*.co.uk"}, "subdomain.bbc.co.uk", true}, + {[]string{"*.*.co.uk"}, "bbc.co.uk", false}, + {[]string{"*.*.co.uk"}, "subdomain.bbc.co.uk", true}, +} + +func TestMatches(t *testing.T) { + for _, tt := range matchTests { + patterns := make([]*Pattern, len(tt.in), len(tt.in)) + for i := range tt.in { + pat, err := NewPattern(tt.in[i]) + if err != nil { + t.Fatalf("error compiling pattern %s: %v", tt.in[i], err) + } + patterns[i] = pat + } + host := &Host{ + Patterns: patterns, + } + got := host.Matches(tt.alias) + if got != tt.want { + t.Errorf("host(%q).Matches(%q): got %v, want %v", tt.in, tt.alias, got, tt.want) + } + } +} diff --git a/lexer.go b/lexer.go index 7cabed9..27bed12 100644 --- a/lexer.go +++ b/lexer.go @@ -36,21 +36,50 @@ func (s *sshLexer) lexComment(previousState sshLexStateFn) sshLexStateFn { } } +// lex the space after an equals sign in a function +func (s *sshLexer) lexRspace() sshLexStateFn { + for { + next := s.peek() + if !isSpace(next) { + break + } + s.skip() + } + return s.lexRvalue +} + +func (s *sshLexer) lexEquals() sshLexStateFn { + for { + next := s.peek() + if next == '=' { + s.emit(tokenEquals) + s.skip() + return s.lexRspace + } + // TODO error handling here; newline eof etc. + if !isSpace(next) { + break + } + s.skip() + } + return s.lexRvalue +} + func (s *sshLexer) lexKey() sshLexStateFn { growingString := "" for r := s.peek(); isKeyChar(r); r = s.peek() { // simplified a lot here - if isSpace(r) { + if isSpace(r) || r == '=' { s.emitWithValue(tokenKey, growingString) s.skip() - return s.lexRvalue + return s.lexEquals } growingString += string(r) s.next() } s.emitWithValue(tokenKey, growingString) - return s.lexVoid + return s.lexEquals } func (s *sshLexer) lexRvalue() sshLexStateFn { diff --git a/parser.go b/parser.go index e0ff46e..8332f81 100644 --- a/parser.go +++ b/parser.go @@ -75,8 +75,12 @@ func (p *sshParser) parseStart() sshParserStateFn { func (p *sshParser) parseKV() sshParserStateFn { key := p.getToken() - p.assume(tokenString) + hasEquals := false val := p.getToken() + if val.typ == tokenEquals { + hasEquals = true + val = p.getToken() + } comment := "" tok := p.peek() if tok.typ == tokenComment && tok.Position.Line == val.Position.Line { @@ -84,16 +88,26 @@ func (p *sshParser) parseKV() sshParserStateFn { comment = tok.val } if key.val == "Host" { - patterns := strings.Split(val.val, " ") - for i := range patterns { - if patterns[i] == "" { - patterns = append(patterns[:i], patterns[i+1:]...) + strPatterns := strings.Split(val.val, " ") + for i := range strPatterns { + if strPatterns[i] == "" { + strPatterns = append(strPatterns[:i], strPatterns[i+1:]...) } } + patterns := make([]*Pattern, len(strPatterns)) + for i := range strPatterns { + pat, err := NewPattern(strPatterns[i]) + if err != nil { + p.raiseError(val, "Invalid host pattern: %v", err) + return nil + } + patterns[i] = pat + } p.config.Hosts = append(p.config.Hosts, &Host{ Patterns: patterns, Nodes: make([]Node, 0), EOLComment: comment, + hasEquals: hasEquals, }) return p.parseStart } @@ -102,6 +116,7 @@ func (p *sshParser) parseKV() sshParserStateFn { Key: key.val, Value: val.val, Comment: comment, + hasEquals: hasEquals, leadingSpace: uint16(key.Position.Col) - 1, position: key.Position, } diff --git a/testdata/anotherfile b/testdata/anotherfile new file mode 100644 index 0000000..c4de676 --- /dev/null +++ b/testdata/anotherfile @@ -0,0 +1,3 @@ +# Not sure that this actually works; Include might need to be relative to the +# load directory. +Compression yes diff --git a/testdata/config3 b/testdata/config3 new file mode 100644 index 0000000..8c15654 --- /dev/null +++ b/testdata/config3 @@ -0,0 +1,31 @@ +Host bastion.*.i.*.example.net + User simon.thulbourn + Port 22 + ForwardAgent yes + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + +Host 10.* + User simon.thulbourn + Port 23 + ForwardAgent yes + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p + +Host 20.20.20.? + User simon.thulbourn + Port 24 + ForwardAgent yes + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p + +Host * + IdentityFile /Users/%u/.ssh/%h/%r/id_rsa + UseKeychain yes + Port 25 diff --git a/testdata/eqsign b/testdata/eqsign new file mode 100644 index 0000000..6332b85 --- /dev/null +++ b/testdata/eqsign @@ -0,0 +1,4 @@ +Host=test.test + Port =1234 + Port2= 5678 + Compression yes diff --git a/testdata/extraspace b/testdata/extraspace new file mode 100644 index 0000000..e9ce2f8 --- /dev/null +++ b/testdata/extraspace @@ -0,0 +1,2 @@ +Host test.test + Port 1234 diff --git a/testdata/include b/testdata/include new file mode 100644 index 0000000..ff8ba51 --- /dev/null +++ b/testdata/include @@ -0,0 +1,2 @@ +Host test.test + Include anotherfile diff --git a/token.go b/token.go index 5964c88..a0ecbb2 100644 --- a/token.go +++ b/token.go @@ -28,6 +28,7 @@ const ( tokenEmptyLine tokenComment tokenKey + tokenEquals tokenString )