Implement Include directive

It's tricky because it involves recursive filesystem parsing, depth
checking and glob matching. But figured it out.

Fixes #4.
This commit is contained in:
Kevin Burke
2017-04-24 11:06:33 -07:00
parent ad36f0d71a
commit 54fabb9a37
6 changed files with 232 additions and 26 deletions

139
config.go
View File

@@ -52,15 +52,17 @@ type UserSettings struct {
IgnoreErrors bool IgnoreErrors bool
} }
func userConfigFinder() string { func homedir() string {
user, err := osuser.Current() user, err := osuser.Current()
var home string
if err == nil { if err == nil {
home = user.HomeDir return user.HomeDir
} else { } else {
home = os.Getenv("HOME") return os.Getenv("HOME")
} }
return filepath.Join(home, ".ssh", "config") }
func userConfigFinder() string {
return filepath.Join(homedir(), ".ssh", "config")
} }
// DefaultUserSettings is the default UserSettings and is used by Get and // DefaultUserSettings is the default UserSettings and is used by Get and
@@ -164,27 +166,44 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
} }
func parseFile(filename string) (*Config, error) { func parseFile(filename string) (*Config, error) {
return parseWithDepth(filename, 0)
}
func parseWithDepth(filename string, depth uint8) (*Config, error) {
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close() defer f.Close()
return Decode(f) return decode(f, isSystem(filename), depth)
}
func isSystem(filename string) bool {
// TODO i'm not sure this is the best way to detect a system repo
return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh")
} }
// Decode reads r into a Config, or returns an error if r could not be parsed as // Decode reads r into a Config, or returns an error if r could not be parsed as
// an SSH config file. // an SSH config file.
func Decode(r io.Reader) (c *Config, err error) { func Decode(r io.Reader) (*Config, error) {
return decode(r, false, 0)
}
func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok { if _, ok := r.(runtime.Error); ok {
panic(r) panic(r)
} }
if e, ok := r.(error); ok && e == ErrDepthExceeded {
err = e
return
}
err = errors.New(r.(string)) err = errors.New(r.(string))
} }
}() }()
c = parseSSH(lexSSH(r)) c = parseSSH(lexSSH(r), system, depth)
return c, err return c, err
} }
@@ -194,11 +213,12 @@ type Config struct {
// A list of hosts to match against. The file begins with an implicit // A list of hosts to match against. The file begins with an implicit
// "Host *" declaration matching all hosts. // "Host *" declaration matching all hosts.
Hosts []*Host Hosts []*Host
depth uint8
} }
// Get finds the first value in the configuration that matches the alias and // Get finds the first value in the configuration that matches the alias and
// contains key. Get returns the empty string if no value was found, the Config // contains key. Get returns the empty string if no value was found, or if the
// contains an invalid conditional Include value. // Config contains an invalid conditional Include value.
// //
// The match for key is case insensitive. // The match for key is case insensitive.
// //
@@ -216,15 +236,17 @@ func (c *Config) Get(alias, key string) (string, error) {
case *KV: case *KV:
// "keys are case insensitive" per the spec // "keys are case insensitive" per the spec
lkey := strings.ToLower(t.Key) lkey := strings.ToLower(t.Key)
if lkey == "include" {
panic("can't handle Include directives")
}
if lkey == "match" { if lkey == "match" {
panic("can't handle Match directives") panic("can't handle Match directives")
} }
if lkey == lowerKey { if lkey == lowerKey {
return t.Value, nil return t.Value, nil
} }
case *Include:
val := t.Get(alias, key)
if val != "" {
return val, nil
}
default: default:
return "", fmt.Errorf("unknown Node type %v", t) return "", fmt.Errorf("unknown Node type %v", t)
} }
@@ -437,6 +459,90 @@ func (e *Empty) String() string {
return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment) return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
} }
type Include struct {
Comment string
parsed bool
// an include directive can include several different files, and wildcards
directives []string
// actual filenames are listed here
files map[string]*Config
leadingSpace uint16
position Position
depth uint8
}
const maxRecurseDepth = 5
var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")
// NewInclude creates a new Include with a list of files to include.
func NewInclude(directives []string, comment string, system bool, depth uint8) (*Include, error) {
if depth > maxRecurseDepth {
return nil, ErrDepthExceeded
}
inc := &Include{
Comment: comment,
directives: directives,
files: make(map[string]*Config),
depth: depth,
}
for i := range directives {
var path string
if filepath.IsAbs(directives[i]) {
path = directives[i]
} else if system {
path = filepath.Join("/etc/ssh", directives[i])
} else {
path = filepath.Join(homedir(), ".ssh", directives[i])
}
matches, err := filepath.Glob(path)
if err != nil {
return nil, err
}
for j := range matches {
if _, ok := inc.files[matches[j]]; ok {
// config already parsed
continue
}
config, err := parseWithDepth(matches[j], depth)
if err != nil {
return nil, err
}
inc.files[matches[j]] = config
}
}
// check depth limit
// for each directive:
// for each file in the list:
// - increment the depth array
// - parse it, return any errors
// - add to files map
return inc, nil
}
func (i *Include) Pos() Position {
return i.position
}
func (i *Include) Add() error {
return nil
}
// Get finds the first value in the Include statement
func (i *Include) Get(alias, key string) string {
for _, cfg := range i.files {
val, err := cfg.Get(alias, key)
if err == nil && val != "" {
return val
}
}
return ""
}
func (i *Include) String() string {
return "TODO"
}
var matchAll *Pattern var matchAll *Pattern
func init() { func init() {
@@ -450,7 +556,12 @@ func init() {
func newConfig() *Config { func newConfig() *Config {
return &Config{ return &Config{
Hosts: []*Host{ Hosts: []*Host{
&Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)}, &Host{
implicit: true,
Patterns: []*Pattern{matchAll},
Nodes: make([]Node, 0),
},
}, },
depth: 0,
} }
} }

View File

@@ -3,6 +3,8 @@ package ssh_config
import ( import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"os"
"path/filepath"
"testing" "testing"
) )
@@ -133,14 +135,75 @@ func TestGetEqsign(t *testing.T) {
} }
} }
var includeFile = []byte(`
# This host should not exist, so we can use it for test purposes / it won't
# interfere with any other configurations.
Host kevinburke.ssh_config.test.example.com
Port 4567
`)
func TestInclude(t *testing.T) { func TestInclude(t *testing.T) {
t.Skip("can't currently handle Include directives") if testing.Short() {
t.Skip("skipping fs write in short mode")
}
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file")
err := ioutil.WriteFile(testPath, includeFile, 0644)
if err != nil {
t.Skipf("couldn't write SSH config file: %v", err.Error())
}
defer os.Remove(testPath)
us := &UserSettings{ us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/include"), userConfigFinder: testConfigFinder("testdata/include"),
} }
val := us.Get("test.test", "Compression") val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
if val != "yes" { if val != "4567" {
t.Errorf("expected to find Compression=yes in included file, got %q", val) t.Errorf("expected to find Port=4567 in included file, got %q", val)
}
}
func TestIncludeSystem(t *testing.T) {
if testing.Short() {
t.Skip("skipping fs write in short mode")
}
testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file")
err := ioutil.WriteFile(testPath, includeFile, 0644)
if err != nil {
t.Skipf("couldn't write SSH config file: %v", err.Error())
}
defer os.Remove(testPath)
us := &UserSettings{
systemConfigFinder: testConfigFinder("testdata/include"),
}
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
if val != "4567" {
t.Errorf("expected to find Port=4567 in included file, got %q", val)
}
}
var recursiveIncludeFile = []byte(`
Host kevinburke.ssh_config.test.example.com
Include kevinburke-ssh-config-recursive-include
`)
func TestIncludeRecursive(t *testing.T) {
if testing.Short() {
t.Skip("skipping fs write in short mode")
}
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include")
err := ioutil.WriteFile(testPath, recursiveIncludeFile, 0644)
if err != nil {
t.Skipf("couldn't write SSH config file: %v", err.Error())
}
defer os.Remove(testPath)
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/include-recursive"),
}
val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port")
if err != ErrDepthExceeded {
t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err)
}
if val != "" {
t.Errorf("non-empty string value %s", val)
} }
} }

View File

@@ -11,15 +11,26 @@ type sshParser struct {
tokensBuffer []token tokensBuffer []token
currentTable []string currentTable []string
seenTableKeys []string seenTableKeys []string
// /etc/ssh parser or local parser - used to find the default for relative
// filepaths in the Include directive
system bool
depth uint8
} }
type sshParserStateFn func() sshParserStateFn type sshParserStateFn func() sshParserStateFn
// Formats and panics an error message based on a token // Formats and panics an error message based on a token
func (p *sshParser) raiseError(tok *token, msg string, args ...interface{}) { func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) {
panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...)) panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...))
} }
func (p *sshParser) raiseError(tok *token, err error) {
if err == ErrDepthExceeded {
panic(err)
}
panic(tok.Position.String() + ": " + err.Error())
}
func (p *sshParser) run() { func (p *sshParser) run() {
for state := p.parseStart; state != nil; { for state := p.parseStart; state != nil; {
state = state() state = state()
@@ -68,7 +79,7 @@ func (p *sshParser) parseStart() sshParserStateFn {
case tokenEOF: case tokenEOF:
return nil return nil
default: default:
p.raiseError(tok, fmt.Sprintf("unexpected token %q\n", tok)) p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok))
} }
return nil return nil
} }
@@ -98,7 +109,7 @@ func (p *sshParser) parseKV() sshParserStateFn {
for i := range strPatterns { for i := range strPatterns {
pat, err := NewPattern(strPatterns[i]) pat, err := NewPattern(strPatterns[i])
if err != nil { if err != nil {
p.raiseError(val, "Invalid host pattern: %v", err) p.raiseErrorf(val, "Invalid host pattern: %v", err)
return nil return nil
} }
patterns[i] = pat patterns[i] = pat
@@ -112,6 +123,19 @@ func (p *sshParser) parseKV() sshParserStateFn {
return p.parseStart return p.parseStart
} }
lastHost := p.config.Hosts[len(p.config.Hosts)-1] lastHost := p.config.Hosts[len(p.config.Hosts)-1]
if key.val == "Include" {
inc, err := NewInclude(strings.Split(val.val, " "), comment, p.system, p.depth+1)
if err == ErrDepthExceeded {
p.raiseError(val, err)
return nil
}
if err != nil {
p.raiseErrorf(val, "Error parsing Include directive: %v", err)
return nil
}
lastHost.Nodes = append(lastHost.Nodes, inc)
return p.parseStart
}
kv := &KV{ kv := &KV{
Key: key.val, Key: key.val,
Value: val.val, Value: val.val,
@@ -140,14 +164,14 @@ func (p *sshParser) parseComment() sshParserStateFn {
func (p *sshParser) assume(typ tokenType) { func (p *sshParser) assume(typ tokenType) {
tok := p.peek() tok := p.peek()
if tok == nil { if tok == nil {
p.raiseError(tok, "was expecting token %s, but token stream is empty", tok) p.raiseErrorf(tok, "was expecting token %s, but token stream is empty", tok)
} }
if tok.typ != typ { if tok.typ != typ {
p.raiseError(tok, "was expecting token %s, but got %s instead", typ, tok) p.raiseErrorf(tok, "was expecting token %s, but got %s instead", typ, tok)
} }
} }
func parseSSH(flow chan token) *Config { func parseSSH(flow chan token, system bool, depth uint8) *Config {
result := newConfig() result := newConfig()
result.position = Position{1, 1} result.position = Position{1, 1}
parser := &sshParser{ parser := &sshParser{
@@ -156,6 +180,8 @@ func parseSSH(flow chan token) *Config {
tokensBuffer: make([]token, 0), tokensBuffer: make([]token, 0),
currentTable: make([]string, 0), currentTable: make([]string, 0),
seenTableKeys: make([]string, 0), seenTableKeys: make([]string, 0),
system: system,
depth: depth,
} }
parser.run() parser.run()
return result return result

6
testdata/include vendored
View File

@@ -1,2 +1,4 @@
Host test.test Host kevinburke.ssh_config.test.example.com
Include anotherfile # This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
# the test.
Include kevinburke-ssh-config-*-file

4
testdata/include-recursive vendored Normal file
View File

@@ -0,0 +1,4 @@
Host kevinburke.ssh_config.test.example.com
# This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
# the test. It should include itself.
Include kevinburke-ssh-config-recursive-include

0
testdata/system-include vendored Normal file
View File