Support retrieving multiple values for a given keyword
IdentityFile among others supports being provided multiple times and aggregated across, potentially, multiple files. Support that workflow by adding GetAll and GetAllStrict alongside the current functions.
This commit is contained in:
committed by
Kevin Burke
parent
42c0635e2f
commit
124166206d
@@ -17,6 +17,14 @@ want to retrieve.
|
||||
port := ssh_config.Get("myhost", "Port")
|
||||
```
|
||||
|
||||
Certain directives can occur multiple times for a host (such as `IdentityFile`),
|
||||
so you should use the `GetAll` or `GetAllStrict` directive to retrieve those
|
||||
instead.
|
||||
|
||||
```go
|
||||
files := ssh_config.GetAll("myhost", "IdentityFile")
|
||||
```
|
||||
|
||||
You can also load a config file and read values from it.
|
||||
|
||||
```go
|
||||
|
||||
174
config.go
174
config.go
@@ -102,6 +102,13 @@ func findVal(c *Config, alias, key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func findAll(c *Config, alias, key string) ([]string, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return c.GetAll(alias, key)
|
||||
}
|
||||
|
||||
// Get finds the first value for key within a declaration that matches the
|
||||
// alias. Get returns the empty string if no value was found, or if IgnoreErrors
|
||||
// is false and we could not parse the configuration file. Use GetStrict to
|
||||
@@ -114,19 +121,51 @@ func Get(alias, key string) string {
|
||||
return DefaultUserSettings.Get(alias, key)
|
||||
}
|
||||
|
||||
// GetAll retrieves zero or more directives for key for the given alias. GetAll
|
||||
// returns nil if no value was found, or if IgnoreErrors is false and we could
|
||||
// not parse the configuration file. Use GetAllStrict to disambiguate the
|
||||
// latter cases.
|
||||
//
|
||||
// In most cases you want to use Get or GetStrict, which returns a single value.
|
||||
// However, a subset of ssh configuration values (IdentityFile, for example)
|
||||
// allow you to specify multiple directives.
|
||||
//
|
||||
// The match for key is case insensitive.
|
||||
//
|
||||
// GetAll is a wrapper around DefaultUserSettings.GetAll.
|
||||
func GetAll(alias, key string) []string {
|
||||
return DefaultUserSettings.GetAll(alias, key)
|
||||
}
|
||||
|
||||
// GetStrict finds the first value for key within a declaration that matches the
|
||||
// alias. If key has a default value and no matching configuration is found, the
|
||||
// default will be returned. For more information on default values and the way
|
||||
// patterns are matched, see the manpage for ssh_config.
|
||||
//
|
||||
// error will be non-nil if and only if a user's configuration file or the
|
||||
// system configuration file could not be parsed, and u.IgnoreErrors is false.
|
||||
// The returned error will be non-nil if and only if a user's configuration file
|
||||
// or the system configuration file could not be parsed, and u.IgnoreErrors is
|
||||
// false.
|
||||
//
|
||||
// GetStrict is a wrapper around DefaultUserSettings.GetStrict.
|
||||
func GetStrict(alias, key string) (string, error) {
|
||||
return DefaultUserSettings.GetStrict(alias, key)
|
||||
}
|
||||
|
||||
// GetAllStrict retrieves zero or more directives for key for the given alias.
|
||||
//
|
||||
// In most cases you want to use Get or GetStrict, which returns a single value.
|
||||
// However, a subset of ssh configuration values (IdentityFile, for example)
|
||||
// allow you to specify multiple directives.
|
||||
//
|
||||
// The returned error will be non-nil if and only if a user's configuration file
|
||||
// or the system configuration file could not be parsed, and u.IgnoreErrors is
|
||||
// false.
|
||||
//
|
||||
// GetAllStrict is a wrapper around DefaultUserSettings.GetAllStrict.
|
||||
func GetAllStrict(alias, key string) ([]string, error) {
|
||||
return DefaultUserSettings.GetAllStrict(alias, key)
|
||||
}
|
||||
|
||||
// Get finds the first value for key within a declaration that matches the
|
||||
// alias. Get returns the empty string if no value was found, or if IgnoreErrors
|
||||
// is false and we could not parse the configuration file. Use GetStrict to
|
||||
@@ -141,6 +180,17 @@ func (u *UserSettings) Get(alias, key string) string {
|
||||
return val
|
||||
}
|
||||
|
||||
// GetAll retrieves zero or more directives for key for the given alias. GetAll
|
||||
// returns nil if no value was found, or if IgnoreErrors is false and we could
|
||||
// not parse the configuration file. Use GetStrict to disambiguate the latter
|
||||
// cases.
|
||||
//
|
||||
// The match for key is case insensitive.
|
||||
func (u *UserSettings) GetAll(alias, key string) []string {
|
||||
val, _ := u.GetAllStrict(alias, key)
|
||||
return val
|
||||
}
|
||||
|
||||
// GetStrict finds the first value for key within a declaration that matches the
|
||||
// alias. If key has a default value and no matching configuration is found, the
|
||||
// default will be returned. For more information on default values and the way
|
||||
@@ -149,6 +199,52 @@ func (u *UserSettings) Get(alias, key string) string {
|
||||
// error will be non-nil if and only if a user's configuration file or the
|
||||
// system configuration file could not be parsed, and u.IgnoreErrors is false.
|
||||
func (u *UserSettings) GetStrict(alias, key string) (string, error) {
|
||||
u.doLoadConfigs()
|
||||
//lint:ignore S1002 I prefer it this way
|
||||
if u.onceErr != nil && u.IgnoreErrors == false {
|
||||
return "", u.onceErr
|
||||
}
|
||||
val, err := findVal(u.userConfig, alias, key)
|
||||
if err != nil || val != "" {
|
||||
return val, err
|
||||
}
|
||||
val2, err2 := findVal(u.systemConfig, alias, key)
|
||||
if err2 != nil || val2 != "" {
|
||||
return val2, err2
|
||||
}
|
||||
return Default(key), nil
|
||||
}
|
||||
|
||||
// GetAllStrict retrieves zero or more directives for key for the given alias.
|
||||
// If key has a default value and no matching configuration is found, the
|
||||
// default will be returned. For more information on default values and the way
|
||||
// patterns are matched, see the manpage for ssh_config.
|
||||
//
|
||||
// The returned error will be non-nil if and only if a user's configuration file
|
||||
// or the system configuration file could not be parsed, and u.IgnoreErrors is
|
||||
// false.
|
||||
func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
|
||||
u.doLoadConfigs()
|
||||
//lint:ignore S1002 I prefer it this way
|
||||
if u.onceErr != nil && u.IgnoreErrors == false {
|
||||
return nil, u.onceErr
|
||||
}
|
||||
val, err := findAll(u.userConfig, alias, key)
|
||||
if err != nil || val != nil {
|
||||
return val, err
|
||||
}
|
||||
val2, err2 := findAll(u.systemConfig, alias, key)
|
||||
if err2 != nil || val2 != nil {
|
||||
return val2, err2
|
||||
}
|
||||
// TODO: IdentityFile has multiple default values that we should return.
|
||||
if def := Default(key); def != "" {
|
||||
return []string{def}, nil
|
||||
}
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (u *UserSettings) doLoadConfigs() {
|
||||
u.loadConfigs.Do(func() {
|
||||
// can't parse user file, that's ok.
|
||||
var filename string
|
||||
@@ -176,19 +272,6 @@ func (u *UserSettings) GetStrict(alias, key string) (string, error) {
|
||||
return
|
||||
}
|
||||
})
|
||||
//lint:ignore S1002 I prefer it this way
|
||||
if u.onceErr != nil && u.IgnoreErrors == false {
|
||||
return "", u.onceErr
|
||||
}
|
||||
val, err := findVal(u.userConfig, alias, key)
|
||||
if err != nil || val != "" {
|
||||
return val, err
|
||||
}
|
||||
val2, err2 := findVal(u.systemConfig, alias, key)
|
||||
if err2 != nil || val2 != "" {
|
||||
return val2, err2
|
||||
}
|
||||
return Default(key), nil
|
||||
}
|
||||
|
||||
func parseFile(filename string) (*Config, error) {
|
||||
@@ -282,6 +365,42 @@ func (c *Config) Get(alias, key string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// GetAll returns all values in the configuration that match the alias and
|
||||
// contains key, or nil if none are present.
|
||||
func (c *Config) GetAll(alias, key string) ([]string, error) {
|
||||
lowerKey := strings.ToLower(key)
|
||||
all := []string(nil)
|
||||
for _, host := range c.Hosts {
|
||||
if !host.Matches(alias) {
|
||||
continue
|
||||
}
|
||||
for _, node := range host.Nodes {
|
||||
switch t := node.(type) {
|
||||
case *Empty:
|
||||
continue
|
||||
case *KV:
|
||||
// "keys are case insensitive" per the spec
|
||||
lkey := strings.ToLower(t.Key)
|
||||
if lkey == "match" {
|
||||
panic("can't handle Match directives")
|
||||
}
|
||||
if lkey == lowerKey {
|
||||
all = append(all, t.Value)
|
||||
}
|
||||
case *Include:
|
||||
val, _ := t.GetAll(alias, key)
|
||||
if len(val) > 0 {
|
||||
all = append(all, val...)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown Node type %v", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// String returns a string representation of the Config file.
|
||||
func (c Config) String() string {
|
||||
return marshal(c).String()
|
||||
@@ -611,6 +730,31 @@ func (inc *Include) Get(alias, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAll finds all values in the Include statement matching the alias and the
|
||||
// given key.
|
||||
func (inc *Include) GetAll(alias, key string) ([]string, error) {
|
||||
inc.mu.Lock()
|
||||
defer inc.mu.Unlock()
|
||||
var vals []string
|
||||
|
||||
// TODO: we search files in any order which is not correct
|
||||
for i := range inc.matches {
|
||||
cfg := inc.files[inc.matches[i]]
|
||||
if cfg == nil {
|
||||
panic("nil cfg")
|
||||
}
|
||||
val, err := cfg.GetAll(alias, key)
|
||||
if err == nil && len(val) != 0 {
|
||||
// In theory if SupportsMultiple was false for this key we could
|
||||
// stop looking here. But the caller has asked us to find all
|
||||
// instances of the keyword (and could use Get() if they wanted) so
|
||||
// let's keep looking.
|
||||
vals = append(vals, val...)
|
||||
}
|
||||
}
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
// String prints out a string representation of this Include directive. Note
|
||||
// included Config files are not printed as part of this representation.
|
||||
func (inc *Include) String() string {
|
||||
|
||||
@@ -67,6 +67,65 @@ func TestGetWithDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllWithDefault(t *testing.T) {
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/config1"),
|
||||
}
|
||||
|
||||
val, err := us.GetAllStrict("wap", "PasswordAuthentication")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != 1 || val[0] != "yes" {
|
||||
t.Errorf("expected to get PasswordAuthentication yes, got %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIdentities(t *testing.T) {
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/identities"),
|
||||
}
|
||||
|
||||
val, err := us.GetAllStrict("hasidentity", "IdentityFile")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != 1 || val[0] != "file1" {
|
||||
t.Errorf(`expected ["file1"], got %v`, val)
|
||||
}
|
||||
|
||||
val, err = us.GetAllStrict("has2identity", "IdentityFile")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != 2 || val[0] != "f1" || val[1] != "f2" {
|
||||
t.Errorf(`expected [\"f1\", \"f2\"], got %v`, val)
|
||||
}
|
||||
|
||||
val, err = us.GetAllStrict("randomhost", "IdentityFile")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != len(defaultProtocol2Identities) {
|
||||
// TODO: return the right values here.
|
||||
log.Printf("expected defaults, got %v", val)
|
||||
} else {
|
||||
for i, v := range defaultProtocol2Identities {
|
||||
if val[i] != v {
|
||||
t.Errorf("invalid %d in val, expected %s got %s", i, v, val[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val, err = us.GetAllStrict("protocol1", "IdentityFile")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != 1 || val[0] != "~/.ssh/identity" {
|
||||
t.Errorf("expected [\"~/.ssh/identity\"], got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInvalidPort(t *testing.T) {
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/invalid-port"),
|
||||
@@ -98,6 +157,20 @@ func TestGetNotFoundNoDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllNotFoundNoDefault(t *testing.T) {
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/config1"),
|
||||
}
|
||||
|
||||
val, err := us.GetAllStrict("wap", "CanonicalDomains")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil err, got %v", err)
|
||||
}
|
||||
if len(val) != 0 {
|
||||
t.Errorf("expected to get CanonicalDomains '', got %q", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWildcard(t *testing.T) {
|
||||
us := &UserSettings{
|
||||
userConfigFinder: testConfigFinder("testdata/config3"),
|
||||
|
||||
11
testdata/identities
vendored
Normal file
11
testdata/identities
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
|
||||
Host hasidentity
|
||||
IdentityFile file1
|
||||
|
||||
Host has2identity
|
||||
IdentityFile f1
|
||||
IdentityFile f2
|
||||
|
||||
Host protocol1
|
||||
Protocol 1
|
||||
|
||||
@@ -160,3 +160,27 @@ var defaults = map[string]string{
|
||||
strings.ToLower("VisualHostKey"): "no",
|
||||
strings.ToLower("XAuthLocation"): "/usr/X11R6/bin/xauth",
|
||||
}
|
||||
|
||||
// these identities are used for SSH protocol 2
|
||||
var defaultProtocol2Identities = []string{
|
||||
"~/.ssh/id_dsa",
|
||||
"~/.ssh/id_ecdsa",
|
||||
"~/.ssh/id_ed25519",
|
||||
"~/.ssh/id_rsa",
|
||||
}
|
||||
|
||||
// these directives support multiple items that can be collected
|
||||
// across multiple files
|
||||
var pluralDirectives = map[string]bool{
|
||||
"CertificateFile": true,
|
||||
"IdentityFile": true,
|
||||
"DynamicForward": true,
|
||||
"RemoteForward": true,
|
||||
"SendEnv": true,
|
||||
"SetEnv": true,
|
||||
}
|
||||
|
||||
// SupportsMultiple reports whether a directive can be specified multiple times.
|
||||
func SupportsMultiple(key string) bool {
|
||||
return pluralDirectives[strings.ToLower(key)]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user