Implement Include directive
It's tricky because it involves recursive filesystem parsing, depth checking and glob matching. But figured it out. Fixes #4.
This commit is contained in:
139
config.go
139
config.go
@@ -52,15 +52,17 @@ type UserSettings struct {
|
|||||||
IgnoreErrors bool
|
IgnoreErrors bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func userConfigFinder() string {
|
func homedir() string {
|
||||||
user, err := osuser.Current()
|
user, err := osuser.Current()
|
||||||
var home string
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
home = user.HomeDir
|
return user.HomeDir
|
||||||
} else {
|
} else {
|
||||||
home = os.Getenv("HOME")
|
return os.Getenv("HOME")
|
||||||
}
|
}
|
||||||
return filepath.Join(home, ".ssh", "config")
|
}
|
||||||
|
|
||||||
|
func userConfigFinder() string {
|
||||||
|
return filepath.Join(homedir(), ".ssh", "config")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultUserSettings is the default UserSettings and is used by Get and
|
// DefaultUserSettings is the default UserSettings and is used by Get and
|
||||||
@@ -164,27 +166,44 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseFile(filename string) (*Config, error) {
|
func parseFile(filename string) (*Config, error) {
|
||||||
|
return parseWithDepth(filename, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWithDepth(filename string, depth uint8) (*Config, error) {
|
||||||
f, err := os.Open(filename)
|
f, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
return Decode(f)
|
return decode(f, isSystem(filename), depth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSystem(filename string) bool {
|
||||||
|
// TODO i'm not sure this is the best way to detect a system repo
|
||||||
|
return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode reads r into a Config, or returns an error if r could not be parsed as
|
// Decode reads r into a Config, or returns an error if r could not be parsed as
|
||||||
// an SSH config file.
|
// an SSH config file.
|
||||||
func Decode(r io.Reader) (c *Config, err error) {
|
func Decode(r io.Reader) (*Config, error) {
|
||||||
|
return decode(r, false, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
if _, ok := r.(runtime.Error); ok {
|
if _, ok := r.(runtime.Error); ok {
|
||||||
panic(r)
|
panic(r)
|
||||||
}
|
}
|
||||||
|
if e, ok := r.(error); ok && e == ErrDepthExceeded {
|
||||||
|
err = e
|
||||||
|
return
|
||||||
|
}
|
||||||
err = errors.New(r.(string))
|
err = errors.New(r.(string))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c = parseSSH(lexSSH(r))
|
c = parseSSH(lexSSH(r), system, depth)
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,11 +213,12 @@ type Config struct {
|
|||||||
// A list of hosts to match against. The file begins with an implicit
|
// A list of hosts to match against. The file begins with an implicit
|
||||||
// "Host *" declaration matching all hosts.
|
// "Host *" declaration matching all hosts.
|
||||||
Hosts []*Host
|
Hosts []*Host
|
||||||
|
depth uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get finds the first value in the configuration that matches the alias and
|
// Get finds the first value in the configuration that matches the alias and
|
||||||
// contains key. Get returns the empty string if no value was found, the Config
|
// contains key. Get returns the empty string if no value was found, or if the
|
||||||
// contains an invalid conditional Include value.
|
// Config contains an invalid conditional Include value.
|
||||||
//
|
//
|
||||||
// The match for key is case insensitive.
|
// The match for key is case insensitive.
|
||||||
//
|
//
|
||||||
@@ -216,15 +236,17 @@ func (c *Config) Get(alias, key string) (string, error) {
|
|||||||
case *KV:
|
case *KV:
|
||||||
// "keys are case insensitive" per the spec
|
// "keys are case insensitive" per the spec
|
||||||
lkey := strings.ToLower(t.Key)
|
lkey := strings.ToLower(t.Key)
|
||||||
if lkey == "include" {
|
|
||||||
panic("can't handle Include directives")
|
|
||||||
}
|
|
||||||
if lkey == "match" {
|
if lkey == "match" {
|
||||||
panic("can't handle Match directives")
|
panic("can't handle Match directives")
|
||||||
}
|
}
|
||||||
if lkey == lowerKey {
|
if lkey == lowerKey {
|
||||||
return t.Value, nil
|
return t.Value, nil
|
||||||
}
|
}
|
||||||
|
case *Include:
|
||||||
|
val := t.Get(alias, key)
|
||||||
|
if val != "" {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("unknown Node type %v", t)
|
return "", fmt.Errorf("unknown Node type %v", t)
|
||||||
}
|
}
|
||||||
@@ -437,6 +459,90 @@ func (e *Empty) String() string {
|
|||||||
return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
|
return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Include struct {
|
||||||
|
Comment string
|
||||||
|
parsed bool
|
||||||
|
// an include directive can include several different files, and wildcards
|
||||||
|
directives []string
|
||||||
|
// actual filenames are listed here
|
||||||
|
files map[string]*Config
|
||||||
|
leadingSpace uint16
|
||||||
|
position Position
|
||||||
|
depth uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxRecurseDepth = 5
|
||||||
|
|
||||||
|
var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")
|
||||||
|
|
||||||
|
// NewInclude creates a new Include with a list of files to include.
|
||||||
|
func NewInclude(directives []string, comment string, system bool, depth uint8) (*Include, error) {
|
||||||
|
if depth > maxRecurseDepth {
|
||||||
|
return nil, ErrDepthExceeded
|
||||||
|
}
|
||||||
|
inc := &Include{
|
||||||
|
Comment: comment,
|
||||||
|
directives: directives,
|
||||||
|
files: make(map[string]*Config),
|
||||||
|
depth: depth,
|
||||||
|
}
|
||||||
|
for i := range directives {
|
||||||
|
var path string
|
||||||
|
if filepath.IsAbs(directives[i]) {
|
||||||
|
path = directives[i]
|
||||||
|
} else if system {
|
||||||
|
path = filepath.Join("/etc/ssh", directives[i])
|
||||||
|
} else {
|
||||||
|
path = filepath.Join(homedir(), ".ssh", directives[i])
|
||||||
|
}
|
||||||
|
matches, err := filepath.Glob(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for j := range matches {
|
||||||
|
if _, ok := inc.files[matches[j]]; ok {
|
||||||
|
// config already parsed
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
config, err := parseWithDepth(matches[j], depth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inc.files[matches[j]] = config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check depth limit
|
||||||
|
// for each directive:
|
||||||
|
// for each file in the list:
|
||||||
|
// - increment the depth array
|
||||||
|
// - parse it, return any errors
|
||||||
|
// - add to files map
|
||||||
|
return inc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Include) Pos() Position {
|
||||||
|
return i.position
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Include) Add() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get finds the first value in the Include statement
|
||||||
|
func (i *Include) Get(alias, key string) string {
|
||||||
|
for _, cfg := range i.files {
|
||||||
|
val, err := cfg.Get(alias, key)
|
||||||
|
if err == nil && val != "" {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Include) String() string {
|
||||||
|
return "TODO"
|
||||||
|
}
|
||||||
|
|
||||||
var matchAll *Pattern
|
var matchAll *Pattern
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -450,7 +556,12 @@ func init() {
|
|||||||
func newConfig() *Config {
|
func newConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Hosts: []*Host{
|
Hosts: []*Host{
|
||||||
&Host{implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0)},
|
&Host{
|
||||||
|
implicit: true,
|
||||||
|
Patterns: []*Pattern{matchAll},
|
||||||
|
Nodes: make([]Node, 0),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
depth: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package ssh_config
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,14 +135,75 @@ func TestGetEqsign(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var includeFile = []byte(`
|
||||||
|
# This host should not exist, so we can use it for test purposes / it won't
|
||||||
|
# interfere with any other configurations.
|
||||||
|
Host kevinburke.ssh_config.test.example.com
|
||||||
|
Port 4567
|
||||||
|
`)
|
||||||
|
|
||||||
func TestInclude(t *testing.T) {
|
func TestInclude(t *testing.T) {
|
||||||
t.Skip("can't currently handle Include directives")
|
if testing.Short() {
|
||||||
|
t.Skip("skipping fs write in short mode")
|
||||||
|
}
|
||||||
|
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file")
|
||||||
|
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||||
|
}
|
||||||
|
defer os.Remove(testPath)
|
||||||
us := &UserSettings{
|
us := &UserSettings{
|
||||||
userConfigFinder: testConfigFinder("testdata/include"),
|
userConfigFinder: testConfigFinder("testdata/include"),
|
||||||
}
|
}
|
||||||
val := us.Get("test.test", "Compression")
|
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||||||
if val != "yes" {
|
if val != "4567" {
|
||||||
t.Errorf("expected to find Compression=yes in included file, got %q", val)
|
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIncludeSystem(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping fs write in short mode")
|
||||||
|
}
|
||||||
|
testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file")
|
||||||
|
err := ioutil.WriteFile(testPath, includeFile, 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||||
|
}
|
||||||
|
defer os.Remove(testPath)
|
||||||
|
us := &UserSettings{
|
||||||
|
systemConfigFinder: testConfigFinder("testdata/include"),
|
||||||
|
}
|
||||||
|
val := us.Get("kevinburke.ssh_config.test.example.com", "Port")
|
||||||
|
if val != "4567" {
|
||||||
|
t.Errorf("expected to find Port=4567 in included file, got %q", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var recursiveIncludeFile = []byte(`
|
||||||
|
Host kevinburke.ssh_config.test.example.com
|
||||||
|
Include kevinburke-ssh-config-recursive-include
|
||||||
|
`)
|
||||||
|
|
||||||
|
func TestIncludeRecursive(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping fs write in short mode")
|
||||||
|
}
|
||||||
|
testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include")
|
||||||
|
err := ioutil.WriteFile(testPath, recursiveIncludeFile, 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("couldn't write SSH config file: %v", err.Error())
|
||||||
|
}
|
||||||
|
defer os.Remove(testPath)
|
||||||
|
us := &UserSettings{
|
||||||
|
userConfigFinder: testConfigFinder("testdata/include-recursive"),
|
||||||
|
}
|
||||||
|
val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port")
|
||||||
|
if err != ErrDepthExceeded {
|
||||||
|
t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
if val != "" {
|
||||||
|
t.Errorf("non-empty string value %s", val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
38
parser.go
38
parser.go
@@ -11,15 +11,26 @@ type sshParser struct {
|
|||||||
tokensBuffer []token
|
tokensBuffer []token
|
||||||
currentTable []string
|
currentTable []string
|
||||||
seenTableKeys []string
|
seenTableKeys []string
|
||||||
|
// /etc/ssh parser or local parser - used to find the default for relative
|
||||||
|
// filepaths in the Include directive
|
||||||
|
system bool
|
||||||
|
depth uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshParserStateFn func() sshParserStateFn
|
type sshParserStateFn func() sshParserStateFn
|
||||||
|
|
||||||
// Formats and panics an error message based on a token
|
// Formats and panics an error message based on a token
|
||||||
func (p *sshParser) raiseError(tok *token, msg string, args ...interface{}) {
|
func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) {
|
||||||
panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...))
|
panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *sshParser) raiseError(tok *token, err error) {
|
||||||
|
if err == ErrDepthExceeded {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
panic(tok.Position.String() + ": " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
func (p *sshParser) run() {
|
func (p *sshParser) run() {
|
||||||
for state := p.parseStart; state != nil; {
|
for state := p.parseStart; state != nil; {
|
||||||
state = state()
|
state = state()
|
||||||
@@ -68,7 +79,7 @@ func (p *sshParser) parseStart() sshParserStateFn {
|
|||||||
case tokenEOF:
|
case tokenEOF:
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
p.raiseError(tok, fmt.Sprintf("unexpected token %q\n", tok))
|
p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -98,7 +109,7 @@ func (p *sshParser) parseKV() sshParserStateFn {
|
|||||||
for i := range strPatterns {
|
for i := range strPatterns {
|
||||||
pat, err := NewPattern(strPatterns[i])
|
pat, err := NewPattern(strPatterns[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.raiseError(val, "Invalid host pattern: %v", err)
|
p.raiseErrorf(val, "Invalid host pattern: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
patterns[i] = pat
|
patterns[i] = pat
|
||||||
@@ -112,6 +123,19 @@ func (p *sshParser) parseKV() sshParserStateFn {
|
|||||||
return p.parseStart
|
return p.parseStart
|
||||||
}
|
}
|
||||||
lastHost := p.config.Hosts[len(p.config.Hosts)-1]
|
lastHost := p.config.Hosts[len(p.config.Hosts)-1]
|
||||||
|
if key.val == "Include" {
|
||||||
|
inc, err := NewInclude(strings.Split(val.val, " "), comment, p.system, p.depth+1)
|
||||||
|
if err == ErrDepthExceeded {
|
||||||
|
p.raiseError(val, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
p.raiseErrorf(val, "Error parsing Include directive: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lastHost.Nodes = append(lastHost.Nodes, inc)
|
||||||
|
return p.parseStart
|
||||||
|
}
|
||||||
kv := &KV{
|
kv := &KV{
|
||||||
Key: key.val,
|
Key: key.val,
|
||||||
Value: val.val,
|
Value: val.val,
|
||||||
@@ -140,14 +164,14 @@ func (p *sshParser) parseComment() sshParserStateFn {
|
|||||||
func (p *sshParser) assume(typ tokenType) {
|
func (p *sshParser) assume(typ tokenType) {
|
||||||
tok := p.peek()
|
tok := p.peek()
|
||||||
if tok == nil {
|
if tok == nil {
|
||||||
p.raiseError(tok, "was expecting token %s, but token stream is empty", tok)
|
p.raiseErrorf(tok, "was expecting token %s, but token stream is empty", tok)
|
||||||
}
|
}
|
||||||
if tok.typ != typ {
|
if tok.typ != typ {
|
||||||
p.raiseError(tok, "was expecting token %s, but got %s instead", typ, tok)
|
p.raiseErrorf(tok, "was expecting token %s, but got %s instead", typ, tok)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSSH(flow chan token) *Config {
|
func parseSSH(flow chan token, system bool, depth uint8) *Config {
|
||||||
result := newConfig()
|
result := newConfig()
|
||||||
result.position = Position{1, 1}
|
result.position = Position{1, 1}
|
||||||
parser := &sshParser{
|
parser := &sshParser{
|
||||||
@@ -156,6 +180,8 @@ func parseSSH(flow chan token) *Config {
|
|||||||
tokensBuffer: make([]token, 0),
|
tokensBuffer: make([]token, 0),
|
||||||
currentTable: make([]string, 0),
|
currentTable: make([]string, 0),
|
||||||
seenTableKeys: make([]string, 0),
|
seenTableKeys: make([]string, 0),
|
||||||
|
system: system,
|
||||||
|
depth: depth,
|
||||||
}
|
}
|
||||||
parser.run()
|
parser.run()
|
||||||
return result
|
return result
|
||||||
|
|||||||
6
testdata/include
vendored
6
testdata/include
vendored
@@ -1,2 +1,4 @@
|
|||||||
Host test.test
|
Host kevinburke.ssh_config.test.example.com
|
||||||
Include anotherfile
|
# This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
|
||||||
|
# the test.
|
||||||
|
Include kevinburke-ssh-config-*-file
|
||||||
|
|||||||
4
testdata/include-recursive
vendored
Normal file
4
testdata/include-recursive
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
Host kevinburke.ssh_config.test.example.com
|
||||||
|
# This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on
|
||||||
|
# the test. It should include itself.
|
||||||
|
Include kevinburke-ssh-config-recursive-include
|
||||||
0
testdata/system-include
vendored
Normal file
0
testdata/system-include
vendored
Normal file
Reference in New Issue
Block a user