Implement Include directive
It's tricky because it involves recursive filesystem parsing, depth checking and glob matching. But figured it out. Fixes #4.
This commit is contained in:
139
config.go
139
config.go
@@ -52,15 +52,17 @@ type UserSettings struct {
|
||||
IgnoreErrors bool
|
||||
}
|
||||
|
||||
func userConfigFinder() string {
|
||||
func homedir() string {
|
||||
user, err := osuser.Current()
|
||||
var home string
|
||||
if err == nil {
|
||||
home = user.HomeDir
|
||||
return user.HomeDir
|
||||
} else {
|
||||
home = os.Getenv("HOME")
|
||||
return os.Getenv("HOME")
|
||||
}
|
||||
return filepath.Join(home, ".ssh", "config")
|
||||
}
|
||||
|
||||
func userConfigFinder() string {
|
||||
return filepath.Join(homedir(), ".ssh", "config")
|
||||
}
|
||||
|
||||
// DefaultUserSettings is the default UserSettings and is used by Get and
|
||||
@@ -164,27 +166,44 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
|
||||
}
|
||||
|
||||
func parseFile(filename string) (*Config, error) {
|
||||
return parseWithDepth(filename, 0)
|
||||
}
|
||||
|
||||
func parseWithDepth(filename string, depth uint8) (*Config, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
return Decode(f)
|
||||
return decode(f, isSystem(filename), depth)
|
||||
}
|
||||
|
||||
func isSystem(filename string) bool {
|
||||
// TODO i'm not sure this is the best way to detect a system repo
|
||||
return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh")
|
||||
}
|
||||
|
||||
// Decode reads r into a Config, or returns an error if r could not be parsed as
|
||||
// an SSH config file.
|
||||
func Decode(r io.Reader) (c *Config, err error) {
|
||||
func Decode(r io.Reader) (*Config, error) {
|
||||
return decode(r, false, 0)
|
||||
}
|
||||
|
||||
func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if e, ok := r.(error); ok && e == ErrDepthExceeded {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
err = errors.New(r.(string))
|
||||
}
|
||||
}()
|
||||
|
||||
c = parseSSH(lexSSH(r))
|
||||
c = parseSSH(lexSSH(r), system, depth)
|
||||
return c, err
|
||||
}
|
||||
|
||||
@@ -194,11 +213,12 @@ type Config struct {
|
||||
// A list of hosts to match against. The file begins with an implicit
|
||||
// "Host *" declaration matching all hosts.
|
||||
Hosts []*Host
|
||||
depth uint8
|
||||
}
|
||||
|
||||
// Get finds the first value in the configuration that matches the alias and
|
||||
// contains key. Get returns the empty string if no value was found, the Config
|
||||
// contains an invalid conditional Include value.
|
||||
// contains key. Get returns the empty string if no value was found, or if the
|
||||
// Config contains an invalid conditional Include value.
|
||||
//
|
||||
// The match for key is case insensitive.
|
||||
//
|
||||
@@ -216,15 +236,17 @@ func (c *Config) Get(alias, key string) (string, error) {
|
||||
case *KV:
|
||||
// "keys are case insensitive" per the spec
|
||||
lkey := strings.ToLower(t.Key)
|
||||
if lkey == "include" {
|
||||
panic("can't handle Include directives")
|
||||
}
|
||||
if lkey == "match" {
|
||||
panic("can't handle Match directives")
|
||||
}
|
||||
if lkey == lowerKey {
|
||||
return t.Value, nil
|
||||
}
|
||||
case *Include:
|
||||
val := t.Get(alias, key)
|
||||
if val != "" {
|
||||
return val, nil
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("unknown Node type %v", t)
|
||||
}
|
||||
@@ -437,6 +459,90 @@ func (e *Empty) String() string {
|
||||
return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
|
||||
}
|
||||
|
||||
type Include struct {
|
||||
Comment string
|
||||
parsed bool
|
||||
// an include directive can include several different files, and wildcards
|
||||
directives []string
|
||||
// actual filenames are listed here
|
||||
files map[string]*Config
|
||||
leadingSpace uint16
|
||||
position Position
|
||||
depth uint8
|
||||
}
|
||||
|
||||
const maxRecurseDepth = 5
|
||||
|
||||
var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")
|
||||
|
||||
// NewInclude creates a new Include with a list of files to include.
|
||||
func NewInclude(directives []string, comment string, system bool, depth uint8) (*Include, error) {
|
||||
if depth > maxRecurseDepth {
|
||||
return nil, ErrDepthExceeded
|
||||
}
|
||||
inc := &Include{
|
||||
Comment: comment,
|
||||
directives: directives,
|
||||
files: make(map[string]*Config),
|
||||
depth: depth,
|
||||
}
|
||||
for i := range directives {
|
||||
var path string
|
||||
if filepath.IsAbs(directives[i]) {
|
||||
path = directives[i]
|
||||
} else if system {
|
||||
path = filepath.Join("/etc/ssh", directives[i])
|
||||
} else {
|
||||
path = filepath.Join(homedir(), ".ssh", directives[i])
|
||||
}
|
||||
matches, err := filepath.Glob(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for j := range matches {
|
||||
if _, ok := inc.files[matches[j]]; ok {
|
||||
// config already parsed
|
||||
continue
|
||||
}
|
||||
config, err := parseWithDepth(matches[j], depth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inc.files[matches[j]] = config
|
||||
}
|
||||
}
|
||||
// check depth limit
|
||||
// for each directive:
|
||||
// for each file in the list:
|
||||
// - increment the depth array
|
||||
// - parse it, return any errors
|
||||
// - add to files map
|
||||
return inc, nil
|
||||
}
|
||||
|
||||
func (i *Include) Pos() Position {
|
||||
return i.position
|
||||
}
|
||||
|
||||
func (i *Include) Add() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get finds the first value in the Include statement
|
||||
func (i *Include) Get(alias, key string) string {
|
||||
for _, cfg := range i.files {
|
||||
val, err := cfg.Get(alias, key)
|
||||
if err == nil && val != "" {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (i *Include) String() string {
|
||||
return "TODO"
|
||||
}
|
||||
|
||||
var matchAll *Pattern
|
||||
|
||||
func init() {
|
||||
@@ -450,7 +556,12 @@ func init() {
|
||||
func newConfig() *Config {
|
||||
return &Config{
|
||||
Hosts: []*Host{
|
||||
&Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)},
|
||||
&Host{
|
||||
implicit: true,
|
||||
Patterns: []*Pattern{matchAll},
|
||||
Nodes: make([]Node, 0),
|
||||
},
|
||||
},
|
||||
depth: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package ssh_config
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -133,14 +135,75 @@ func TestGetEqsign(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var includeFile = []byte(`
|
||||
# This host should not exist, so we can use it for test purposes / it won't
|
||||
# interfere with any other configurations.
|
||||
Host kevinburke.ssh_config.test.example.com
|
||||
Port 4567
|
||||
`)
|
||||
|
||||
func TestInclude(t *testing.T) {
|
||||
t.Skip("can't currently handle Include directives")
|
||||
if testing.Short() {
|
||||
t.Skip("skipping fs write in short mode")
|
||||
}
|
||||
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file")
|
||||
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||||
if err != nil {
|
||||
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||
}
|
||||
defer os.Remove(testPath)
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/include"),
|
||||
}
|
||||
val := us.Get("test.test", "Compression")
|
||||
if val != "yes" {
|
||||
t.Errorf("expected to find Compression=yes in included file, got %q", val)
|
||||
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||||
if val != "4567" {
|
||||
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncludeSystem(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping fs write in short mode")
|
||||
}
|
||||
testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file")
|
||||
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||||
if err != nil {
|
||||
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||
}
|
||||
defer os.Remove(testPath)
|
||||
us := &UserSettings{
|
||||
systemConfigFinder: testConfigFinder("testdata/include"),
|
||||
}
|
||||
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||||
if val != "4567" {
|
||||
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
var recursiveIncludeFile = []byte(`
|
||||
Host kevinburke.ssh_config.test.example.com
|
||||
Include kevinburke-ssh-config-recursive-include
|
||||
`)
|
||||
|
||||
func TestIncludeRecursive(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping fs write in short mode")
|
||||
}
|
||||
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include")
|
||||
err := ioutil.WriteFile(testPath, recursiveIncludeFile, 0644)
|
||||
if err != nil {
|
||||
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||
}
|
||||
defer os.Remove(testPath)
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/include-recursive"),
|
||||
}
|
||||
val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port")
|
||||
if err != ErrDepthExceeded {
|
||||
t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err)
|
||||
}
|
||||
if val != "" {
|
||||
t.Errorf("non-empty string value %s", val)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
38
parser.go
38
parser.go
@@ -11,15 +11,26 @@ type sshParser struct {
|
||||
tokensBuffer []token
|
||||
currentTable []string
|
||||
seenTableKeys []string
|
||||
// /etc/ssh parser or local parser - used to find the default for relative
|
||||
// filepaths in the Include directive
|
||||
system bool
|
||||
depth uint8
|
||||
}
|
||||
|
||||
type sshParserStateFn func() sshParserStateFn
|
||||
|
||||
// Formats and panics an error message based on a token
|
||||
func (p *sshParser) raiseError(tok *token, msg string, args ...interface{}) {
|
||||
func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) {
|
||||
panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...))
|
||||
}
|
||||
|
||||
func (p *sshParser) raiseError(tok *token, err error) {
|
||||
if err == ErrDepthExceeded {
|
||||
panic(err)
|
||||
}
|
||||
panic(tok.Position.String() + ": " + err.Error())
|
||||
}
|
||||
|
||||
func (p *sshParser) run() {
|
||||
for state := p.parseStart; state != nil; {
|
||||
state = state()
|
||||
@@ -68,7 +79,7 @@ func (p *sshParser) parseStart() sshParserStateFn {
|
||||
case tokenEOF:
|
||||
return nil
|
||||
default:
|
||||
p.raiseError(tok, fmt.Sprintf("unexpected token %q\n", tok))
|
||||
p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -98,7 +109,7 @@ func (p *sshParser) parseKV() sshParserStateFn {
|
||||
for i := range strPatterns {
|
||||
pat, err := NewPattern(strPatterns[i])
|
||||
if err != nil {
|
||||
p.raiseError(val, "Invalid host pattern: %v", err)
|
||||
p.raiseErrorf(val, "Invalid host pattern: %v", err)
|
||||
return nil
|
||||
}
|
||||
patterns[i] = pat
|
||||
@@ -112,6 +123,19 @@ func (p *sshParser) parseKV() sshParserStateFn {
|
||||
return p.parseStart
|
||||
}
|
||||
lastHost := p.config.Hosts[len(p.config.Hosts)-1]
|
||||
if key.val == "Include" {
|
||||
inc, err := NewInclude(strings.Split(val.val, " "), comment, p.system, p.depth+1)
|
||||
if err == ErrDepthExceeded {
|
||||
p.raiseError(val, err)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
p.raiseErrorf(val, "Error parsing Include directive: %v", err)
|
||||
return nil
|
||||
}
|
||||
lastHost.Nodes = append(lastHost.Nodes, inc)
|
||||
return p.parseStart
|
||||
}
|
||||
kv := &KV{
|
||||
Key: key.val,
|
||||
Value: val.val,
|
||||
@@ -140,14 +164,14 @@ func (p *sshParser) parseComment() sshParserStateFn {
|
||||
func (p *sshParser) assume(typ tokenType) {
|
||||
tok := p.peek()
|
||||
if tok == nil {
|
||||
p.raiseError(tok, "was expecting token %s, but token stream is empty", tok)
|
||||
p.raiseErrorf(tok, "was expecting token %s, but token stream is empty", tok)
|
||||
}
|
||||
if tok.typ != typ {
|
||||
p.raiseError(tok, "was expecting token %s, but got %s instead", typ, tok)
|
||||
p.raiseErrorf(tok, "was expecting token %s, but got %s instead", typ, tok)
|
||||
}
|
||||
}
|
||||
|
||||
func parseSSH(flow chan token) *Config {
|
||||
func parseSSH(flow chan token, system bool, depth uint8) *Config {
|
||||
result := newConfig()
|
||||
result.position = Position{1, 1}
|
||||
parser := &sshParser{
|
||||
@@ -156,6 +180,8 @@ func parseSSH(flow chan token) *Config {
|
||||
tokensBuffer: make([]token, 0),
|
||||
currentTable: make([]string, 0),
|
||||
seenTableKeys: make([]string, 0),
|
||||
system: system,
|
||||
depth: depth,
|
||||
}
|
||||
parser.run()
|
||||
return result
|
||||
|
||||
6
testdata/include
vendored
6
testdata/include
vendored
@@ -1,2 +1,4 @@
|
||||
Host test.test
|
||||
Include anotherfile
|
||||
Host kevinburke.ssh_config.test.example.com
|
||||
# This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
|
||||
# the test.
|
||||
Include kevinburke-ssh-config-*-file
|
||||
|
||||
4
testdata/include-recursive
vendored
Normal file
4
testdata/include-recursive
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
Host kevinburke.ssh_config.test.example.com
|
||||
# This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
|
||||
# the test. It should include itself.
|
||||
Include kevinburke-ssh-config-recursive-include
|
||||
0
testdata/system-include
vendored
Normal file
0
testdata/system-include
vendored
Normal file
Reference in New Issue
Block a user