From 54fabb9a3764ca25f792b23dd5bae5474afa1d6b Mon Sep 17 00:00:00 2001 From: Kevin Burke Date: Mon, 24 Apr 2017 11:06:33 -0700 Subject: [PATCH] Implement Include directive It's tricky because it involves recursive filesystem parsing, depth checking and glob matching. But figured it out. Fixes #4. --- config.go | 139 +++++++++++++++++++++++++++++++++---- config_test.go | 71 +++++++++++++++++-- parser.go | 38 ++++++++-- testdata/include | 6 +- testdata/include-recursive | 4 ++ testdata/system-include | 0 6 files changed, 232 insertions(+), 26 deletions(-) create mode 100644 testdata/include-recursive create mode 100644 testdata/system-include diff --git a/config.go b/config.go index 47602b8..c7120b9 100644 --- a/config.go +++ b/config.go @@ -52,15 +52,17 @@ type UserSettings struct { IgnoreErrors bool } -func userConfigFinder() string { +func homedir() string { user, err := osuser.Current() - var home string if err == nil { - home = user.HomeDir + return user.HomeDir } else { - home = os.Getenv("HOME") + return os.Getenv("HOME") } - return filepath.Join(home, ".ssh", "config") +} + +func userConfigFinder() string { + return filepath.Join(homedir(), ".ssh", "config") } // DefaultUserSettings is the default UserSettings and is used by Get and @@ -164,27 +166,44 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) { } func parseFile(filename string) (*Config, error) { + return parseWithDepth(filename, 0) +} + +func parseWithDepth(filename string, depth uint8) (*Config, error) { f, err := os.Open(filename) if err != nil { return nil, err } defer f.Close() - return Decode(f) + return decode(f, isSystem(filename), depth) +} + +func isSystem(filename string) bool { + // TODO i'm not sure this is the best way to detect a system repo + return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh") } // Decode reads r into a Config, or returns an error if r could not be parsed as // an SSH config file. -func Decode(r io.Reader) (c *Config, err error) { +func Decode(r io.Reader) (*Config, error) { + return decode(r, false, 0) +} + +func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) { defer func() { if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { panic(r) } + if e, ok := r.(error); ok && e == ErrDepthExceeded { + err = e + return + } err = errors.New(r.(string)) } }() - c = parseSSH(lexSSH(r)) + c = parseSSH(lexSSH(r), system, depth) return c, err } @@ -194,11 +213,12 @@ type Config struct { // A list of hosts to match against. The file begins with an implicit // "Host *" declaration matching all hosts. Hosts []*Host + depth uint8 } // Get finds the first value in the configuration that matches the alias and -// contains key. Get returns the empty string if no value was found, the Config -// contains an invalid conditional Include value. +// contains key. Get returns the empty string if no value was found, or if the +// Config contains an invalid conditional Include value. // // The match for key is case insensitive. // @@ -216,15 +236,17 @@ func (c *Config) Get(alias, key string) (string, error) { case *KV: // "keys are case insensitive" per the spec lkey := strings.ToLower(t.Key) - if lkey == "include" { - panic("can't handle Include directives") - } if lkey == "match" { panic("can't handle Match directives") } if lkey == lowerKey { return t.Value, nil } + case *Include: + val := t.Get(alias, key) + if val != "" { + return val, nil + } default: return "", fmt.Errorf("unknown Node type %v", t) } @@ -437,6 +459,90 @@ func (e *Empty) String() string { return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment) } +type Include struct { + Comment string + parsed bool + // an include directive can include several different files, and wildcards + directives []string + // actual filenames are listed here + files map[string]*Config + leadingSpace uint16 + position Position + depth uint8 +} + +const maxRecurseDepth = 5 + +var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded") + +// NewInclude creates a new Include with a list of files to include. +func NewInclude(directives []string, comment string, system bool, depth uint8) (*Include, error) { + if depth > maxRecurseDepth { + return nil, ErrDepthExceeded + } + inc := &Include{ + Comment: comment, + directives: directives, + files: make(map[string]*Config), + depth: depth, + } + for i := range directives { + var path string + if filepath.IsAbs(directives[i]) { + path = directives[i] + } else if system { + path = filepath.Join("/etc/ssh", directives[i]) + } else { + path = filepath.Join(homedir(), ".ssh", directives[i]) + } + matches, err := filepath.Glob(path) + if err != nil { + return nil, err + } + for j := range matches { + if _, ok := inc.files[matches[j]]; ok { + // config already parsed + continue + } + config, err := parseWithDepth(matches[j], depth) + if err != nil { + return nil, err + } + inc.files[matches[j]] = config + } + } + // check depth limit + // for each directive: + // for each file in the list: + // - increment the depth array + // - parse it, return any errors + // - add to files map + return inc, nil +} + +func (i *Include) Pos() Position { + return i.position +} + +func (i *Include) Add() error { + return nil +} + +// Get finds the first value in the Include statement +func (i *Include) Get(alias, key string) string { + for _, cfg := range i.files { + val, err := cfg.Get(alias, key) + if err == nil && val != "" { + return val + } + } + return "" +} + +func (i *Include) String() string { + return "TODO" +} + var matchAll *Pattern func init() { @@ -450,7 +556,12 @@ func init() { func newConfig() *Config { return &Config{ Hosts: []*Host{ - &Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)}, + &Host{ + implicit: true, + Patterns: []*Pattern{matchAll}, + Nodes: make([]Node, 0), + }, }, + depth: 0, } } diff --git a/config_test.go b/config_test.go index 9f2db4a..8cd1003 100644 --- a/config_test.go +++ b/config_test.go @@ -3,6 +3,8 @@ package ssh_config import ( "bytes" "io/ioutil" + "os" + "path/filepath" "testing" ) @@ -133,14 +135,75 @@ func TestGetEqsign(t *testing.T) { } } +var includeFile = []byte(` +# This host should not exist, so we can use it for test purposes / it won't +# interfere with any other configurations. +Host kevinburke.ssh_config.test.example.com + Port 4567 +`) + func TestInclude(t *testing.T) { - t.Skip("can't currently handle Include directives") + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file") + err := ioutil.WriteFile(testPath, includeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) us := &UserSettings{ userConfigFinder: testConfigFinder("testdata/include"), } - val := us.Get("test.test", "Compression") - if val != "yes" { - t.Errorf("expected to find Compression=yes in included file, got %q", val) + val := us.Get("kevinburke.ssh_config.test.example.com", "Port") + if val != "4567" { + t.Errorf("expected to find Port=4567 in included file, got %q", val) + } +} + +func TestIncludeSystem(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file") + err := ioutil.WriteFile(testPath, includeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) + us := &UserSettings{ + systemConfigFinder: testConfigFinder("testdata/include"), + } + val := us.Get("kevinburke.ssh_config.test.example.com", "Port") + if val != "4567" { + t.Errorf("expected to find Port=4567 in included file, got %q", val) + } +} + +var recursiveIncludeFile = []byte(` +Host kevinburke.ssh_config.test.example.com + Include kevinburke-ssh-config-recursive-include +`) + +func TestIncludeRecursive(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include") + err := ioutil.WriteFile(testPath, recursiveIncludeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/include-recursive"), + } + val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port") + if err != ErrDepthExceeded { + t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err) + } + if val != "" { + t.Errorf("non-empty string value %s", val) } } diff --git a/parser.go b/parser.go index 8332f81..959b7bd 100644 --- a/parser.go +++ b/parser.go @@ -11,15 +11,26 @@ type sshParser struct { tokensBuffer []token currentTable []string seenTableKeys []string + // /etc/ssh parser or local parser - used to find the default for relative + // filepaths in the Include directive + system bool + depth uint8 } type sshParserStateFn func() sshParserStateFn // Formats and panics an error message based on a token -func (p *sshParser) raiseError(tok *token, msg string, args ...interface{}) { +func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) { panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...)) } +func (p *sshParser) raiseError(tok *token, err error) { + if err == ErrDepthExceeded { + panic(err) + } + panic(tok.Position.String() + ": " + err.Error()) +} + func (p *sshParser) run() { for state := p.parseStart; state != nil; { state = state() @@ -68,7 +79,7 @@ func (p *sshParser) parseStart() sshParserStateFn { case tokenEOF: return nil default: - p.raiseError(tok, fmt.Sprintf("unexpected token %q\n", tok)) + p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok)) } return nil } @@ -98,7 +109,7 @@ func (p *sshParser) parseKV() sshParserStateFn { for i := range strPatterns { pat, err := NewPattern(strPatterns[i]) if err != nil { - p.raiseError(val, "Invalid host pattern: %v", err) + p.raiseErrorf(val, "Invalid host pattern: %v", err) return nil } patterns[i] = pat @@ -112,6 +123,19 @@ func (p *sshParser) parseKV() sshParserStateFn { return p.parseStart } lastHost := p.config.Hosts[len(p.config.Hosts)-1] + if key.val == "Include" { + inc, err := NewInclude(strings.Split(val.val, " "), comment, p.system, p.depth+1) + if err == ErrDepthExceeded { + p.raiseError(val, err) + return nil + } + if err != nil { + p.raiseErrorf(val, "Error parsing Include directive: %v", err) + return nil + } + lastHost.Nodes = append(lastHost.Nodes, inc) + return p.parseStart + } kv := &KV{ Key: key.val, Value: val.val, @@ -140,14 +164,14 @@ func (p *sshParser) parseComment() sshParserStateFn { func (p *sshParser) assume(typ tokenType) { tok := p.peek() if tok == nil { - p.raiseError(tok, "was expecting token %s, but token stream is empty", tok) + p.raiseErrorf(tok, "was expecting token %s, but token stream is empty", tok) } if tok.typ != typ { - p.raiseError(tok, "was expecting token %s, but got %s instead", typ, tok) + p.raiseErrorf(tok, "was expecting token %s, but got %s instead", typ, tok) } } -func parseSSH(flow chan token) *Config { +func parseSSH(flow chan token, system bool, depth uint8) *Config { result := newConfig() result.position = Position{1, 1} parser := &sshParser{ @@ -156,6 +180,8 @@ func parseSSH(flow chan token) *Config { tokensBuffer: make([]token, 0), currentTable: make([]string, 0), seenTableKeys: make([]string, 0), + system: system, + depth: depth, } parser.run() return result diff --git a/testdata/include b/testdata/include index ff8ba51..0a711ca 100644 --- a/testdata/include +++ b/testdata/include @@ -1,2 +1,4 @@ -Host test.test - Include anotherfile +Host kevinburke.ssh_config.test.example.com + # This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on + # the test. + Include kevinburke-ssh-config-*-file diff --git a/testdata/include-recursive b/testdata/include-recursive new file mode 100644 index 0000000..8a3cd3d --- /dev/null +++ b/testdata/include-recursive @@ -0,0 +1,4 @@ +Host kevinburke.ssh_config.test.example.com + # This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on + # the test. It should include itself. + Include kevinburke-ssh-config-recursive-include diff --git a/testdata/system-include b/testdata/system-include new file mode 100644 index 0000000..e69de29