Support retrieving multiple values for a given keyword

IdentityFile among others supports being provided multiple times and
aggregated across, potentially, multiple files. Support that workflow
by adding GetAll and GetAllStrict alongside the current functions.
This commit is contained in:
Dustin Spicuzza
2018-07-11 00:37:12 -04:00
committed by Kevin Burke
parent 42c0635e2f
commit 124166206d
5 changed files with 275 additions and 15 deletions

View File

@@ -17,6 +17,14 @@ want to retrieve.
port := ssh_config.Get("myhost", "Port")
```
Certain directives can occur multiple times for a host (such as `IdentityFile`),
so you should use the `GetAll` or `GetAllStrict` directive to retrieve those
instead.
```go
files := ssh_config.GetAll("myhost", "IdentityFile")
```
You can also load a config file and read values from it.
```go

174
config.go
View File

@@ -102,6 +102,13 @@ func findVal(c *Config, alias, key string) (string, error) {
return val, nil
}
func findAll(c *Config, alias, key string) ([]string, error) {
if c == nil {
return nil, nil
}
return c.GetAll(alias, key)
}
// Get finds the first value for key within a declaration that matches the
// alias. Get returns the empty string if no value was found, or if IgnoreErrors
// is false and we could not parse the configuration file. Use GetStrict to
@@ -114,19 +121,51 @@ func Get(alias, key string) string {
return DefaultUserSettings.Get(alias, key)
}
// GetAll retrieves zero or more directives for key for the given alias. GetAll
// returns nil if no value was found, or if IgnoreErrors is false and we could
// not parse the configuration file. Use GetAllStrict to disambiguate the
// latter cases.
//
// In most cases you want to use Get or GetStrict, which returns a single value.
// However, a subset of ssh configuration values (IdentityFile, for example)
// allow you to specify multiple directives.
//
// The match for key is case insensitive.
//
// GetAll is a wrapper around DefaultUserSettings.GetAll.
func GetAll(alias, key string) []string {
return DefaultUserSettings.GetAll(alias, key)
}
// GetStrict finds the first value for key within a declaration that matches the
// alias. If key has a default value and no matching configuration is found, the
// default will be returned. For more information on default values and the way
// patterns are matched, see the manpage for ssh_config.
//
// error will be non-nil if and only if a user's configuration file or the
// system configuration file could not be parsed, and u.IgnoreErrors is false.
// The returned error will be non-nil if and only if a user's configuration file
// or the system configuration file could not be parsed, and u.IgnoreErrors is
// false.
//
// GetStrict is a wrapper around DefaultUserSettings.GetStrict.
func GetStrict(alias, key string) (string, error) {
return DefaultUserSettings.GetStrict(alias, key)
}
// GetAllStrict retrieves zero or more directives for key for the given alias.
//
// In most cases you want to use Get or GetStrict, which returns a single value.
// However, a subset of ssh configuration values (IdentityFile, for example)
// allow you to specify multiple directives.
//
// The returned error will be non-nil if and only if a user's configuration file
// or the system configuration file could not be parsed, and u.IgnoreErrors is
// false.
//
// GetAllStrict is a wrapper around DefaultUserSettings.GetAllStrict.
func GetAllStrict(alias, key string) ([]string, error) {
return DefaultUserSettings.GetAllStrict(alias, key)
}
// Get finds the first value for key within a declaration that matches the
// alias. Get returns the empty string if no value was found, or if IgnoreErrors
// is false and we could not parse the configuration file. Use GetStrict to
@@ -141,6 +180,17 @@ func (u *UserSettings) Get(alias, key string) string {
return val
}
// GetAll retrieves zero or more directives for key for the given alias. GetAll
// returns nil if no value was found, or if IgnoreErrors is false and we could
// not parse the configuration file. Use GetStrict to disambiguate the latter
// cases.
//
// The match for key is case insensitive.
func (u *UserSettings) GetAll(alias, key string) []string {
val, _ := u.GetAllStrict(alias, key)
return val
}
// GetStrict finds the first value for key within a declaration that matches the
// alias. If key has a default value and no matching configuration is found, the
// default will be returned. For more information on default values and the way
@@ -149,6 +199,52 @@ func (u *UserSettings) Get(alias, key string) string {
// error will be non-nil if and only if a user's configuration file or the
// system configuration file could not be parsed, and u.IgnoreErrors is false.
func (u *UserSettings) GetStrict(alias, key string) (string, error) {
u.doLoadConfigs()
//lint:ignore S1002 I prefer it this way
if u.onceErr != nil && u.IgnoreErrors == false {
return "", u.onceErr
}
val, err := findVal(u.userConfig, alias, key)
if err != nil || val != "" {
return val, err
}
val2, err2 := findVal(u.systemConfig, alias, key)
if err2 != nil || val2 != "" {
return val2, err2
}
return Default(key), nil
}
// GetAllStrict retrieves zero or more directives for key for the given alias.
// If key has a default value and no matching configuration is found, the
// default will be returned. For more information on default values and the way
// patterns are matched, see the manpage for ssh_config.
//
// The returned error will be non-nil if and only if a user's configuration file
// or the system configuration file could not be parsed, and u.IgnoreErrors is
// false.
func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
u.doLoadConfigs()
//lint:ignore S1002 I prefer it this way
if u.onceErr != nil && u.IgnoreErrors == false {
return nil, u.onceErr
}
val, err := findAll(u.userConfig, alias, key)
if err != nil || val != nil {
return val, err
}
val2, err2 := findAll(u.systemConfig, alias, key)
if err2 != nil || val2 != nil {
return val2, err2
}
// TODO: IdentityFile has multiple default values that we should return.
if def := Default(key); def != "" {
return []string{def}, nil
}
return []string{}, nil
}
func (u *UserSettings) doLoadConfigs() {
u.loadConfigs.Do(func() {
// can't parse user file, that's ok.
var filename string
@@ -176,19 +272,6 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
return
}
})
//lint:ignore S1002 I prefer it this way
if u.onceErr != nil && u.IgnoreErrors == false {
return "", u.onceErr
}
val, err := findVal(u.userConfig, alias, key)
if err != nil || val != "" {
return val, err
}
val2, err2 := findVal(u.systemConfig, alias, key)
if err2 != nil || val2 != "" {
return val2, err2
}
return Default(key), nil
}
func parseFile(filename string) (*Config, error) {
@@ -282,6 +365,42 @@ func (c *Config) Get(alias, key string) (string, error) {
return "", nil
}
// GetAll returns all values in the configuration that match the alias and
// contains key, or nil if none are present.
func (c *Config) GetAll(alias, key string) ([]string, error) {
lowerKey := strings.ToLower(key)
all := []string(nil)
for _, host := range c.Hosts {
if !host.Matches(alias) {
continue
}
for _, node := range host.Nodes {
switch t := node.(type) {
case *Empty:
continue
case *KV:
// "keys are case insensitive" per the spec
lkey := strings.ToLower(t.Key)
if lkey == "match" {
panic("can't handle Match directives")
}
if lkey == lowerKey {
all = append(all, t.Value)
}
case *Include:
val, _ := t.GetAll(alias, key)
if len(val) > 0 {
all = append(all, val...)
}
default:
return nil, fmt.Errorf("unknown Node type %v", t)
}
}
}
return all, nil
}
// String returns a string representation of the Config file.
func (c Config) String() string {
return marshal(c).String()
@@ -611,6 +730,31 @@ func (inc *Include) Get(alias, key string) string {
return ""
}
// GetAll finds all values in the Include statement matching the alias and the
// given key.
func (inc *Include) GetAll(alias, key string) ([]string, error) {
inc.mu.Lock()
defer inc.mu.Unlock()
var vals []string
// TODO: we search files in any order which is not correct
for i := range inc.matches {
cfg := inc.files[inc.matches[i]]
if cfg == nil {
panic("nil cfg")
}
val, err := cfg.GetAll(alias, key)
if err == nil && len(val) != 0 {
// In theory if SupportsMultiple was false for this key we could
// stop looking here. But the caller has asked us to find all
// instances of the keyword (and could use Get() if they wanted) so
// let's keep looking.
vals = append(vals, val...)
}
}
return vals, nil
}
// String prints out a string representation of this Include directive. Note
// included Config files are not printed as part of this representation.
func (inc *Include) String() string {

View File

@@ -67,6 +67,65 @@ func TestGetWithDefault(t *testing.T) {
}
}
func TestGetAllWithDefault(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config1"),
}
val, err := us.GetAllStrict("wap", "PasswordAuthentication")
if err != nil {
t.Fatalf("expected nil err, got %v", err)
}
if len(val) != 1 || val[0] != "yes" {
t.Errorf("expected to get PasswordAuthentication yes, got %q", val)
}
}
func TestGetIdentities(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/identities"),
}
val, err := us.GetAllStrict("hasidentity", "IdentityFile")
if err != nil {
t.Errorf("expected nil err, got %v", err)
}
if len(val) != 1 || val[0] != "file1" {
t.Errorf(`expected ["file1"], got %v`, val)
}
val, err = us.GetAllStrict("has2identity", "IdentityFile")
if err != nil {
t.Errorf("expected nil err, got %v", err)
}
if len(val) != 2 || val[0] != "f1" || val[1] != "f2" {
t.Errorf(`expected [\"f1\", \"f2\"], got %v`, val)
}
val, err = us.GetAllStrict("randomhost", "IdentityFile")
if err != nil {
t.Errorf("expected nil err, got %v", err)
}
if len(val) != len(defaultProtocol2Identities) {
// TODO: return the right values here.
log.Printf("expected defaults, got %v", val)
} else {
for i, v := range defaultProtocol2Identities {
if val[i] != v {
t.Errorf("invalid %d in val, expected %s got %s", i, v, val[i])
}
}
}
val, err = us.GetAllStrict("protocol1", "IdentityFile")
if err != nil {
t.Errorf("expected nil err, got %v", err)
}
if len(val) != 1 || val[0] != "~/.ssh/identity" {
t.Errorf("expected [\"~/.ssh/identity\"], got %v", val)
}
}
func TestGetInvalidPort(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/invalid-port"),
@@ -98,6 +157,20 @@ func TestGetNotFoundNoDefault(t *testing.T) {
}
}
func TestGetAllNotFoundNoDefault(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config1"),
}
val, err := us.GetAllStrict("wap", "CanonicalDomains")
if err != nil {
t.Fatalf("expected nil err, got %v", err)
}
if len(val) != 0 {
t.Errorf("expected to get CanonicalDomains '', got %q", val)
}
}
func TestGetWildcard(t *testing.T) {
us := &UserSettings{
userConfigFinder: testConfigFinder("testdata/config3"),

11
testdata/identities vendored Normal file
View File

@@ -0,0 +1,11 @@
Host hasidentity
IdentityFile file1
Host has2identity
IdentityFile f1
IdentityFile f2
Host protocol1
Protocol 1

View File

@@ -160,3 +160,27 @@ var defaults = map[string]string{
strings.ToLower("VisualHostKey"): "no",
strings.ToLower("XAuthLocation"): "/usr/X11R6/bin/xauth",
}
// these identities are used for SSH protocol 2
var defaultProtocol2Identities = []string{
"~/.ssh/id_dsa",
"~/.ssh/id_ecdsa",
"~/.ssh/id_ed25519",
"~/.ssh/id_rsa",
}
// these directives support multiple items that can be collected
// across multiple files
var pluralDirectives = map[string]bool{
"CertificateFile": true,
"IdentityFile": true,
"DynamicForward": true,
"RemoteForward": true,
"SendEnv": true,
"SetEnv": true,
}
// SupportsMultiple reports whether a directive can be specified multiple times.
func SupportsMultiple(key string) bool {
return pluralDirectives[strings.ToLower(key)]
}