diff --git a/.gitignore b/.gitignore index a6ef824..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ -/bazel-* diff --git a/config.go b/config.go index 9f97ed6..297b25f 100644 --- a/config.go +++ b/config.go @@ -34,6 +34,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "os" osuser "os/user" "path/filepath" @@ -195,26 +196,29 @@ func parseFile(filename string) (*Config, error) { } func parseWithDepth(filename string, depth uint8) (*Config, error) { - f, err := os.Open(filename) + b, err := ioutil.ReadFile(filename) if err != nil { return nil, err } - defer f.Close() - return decode(f, isSystem(filename), depth) + return decodeBytes(b, isSystem(filename), depth) } func isSystem(filename string) bool { - // TODO i'm not sure this is the best way to detect a system repo + // TODO: not sure this is the best way to detect a system repo return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh") } // Decode reads r into a Config, or returns an error if r could not be parsed as // an SSH config file. func Decode(r io.Reader) (*Config, error) { - return decode(r, false, 0) + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + return decodeBytes(b, false, 0) } -func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) { +func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) { defer func() { if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { @@ -228,7 +232,7 @@ func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) { } }() - c = parseSSH(lexSSH(r), system, depth) + c = parseSSH(lexSSH(b), system, depth) return c, err } diff --git a/lexer.go b/lexer.go index 5c1c39b..11680b4 100644 --- a/lexer.go +++ b/lexer.go @@ -1,17 +1,17 @@ package ssh_config import ( - "io" - - buffruneio "github.com/pelletier/go-buffruneio" + "bytes" ) // Define state functions type sshLexStateFn func() sshLexStateFn type sshLexer struct { - input *buffruneio.Reader // Textual source - buffer []rune // Runes composing the current token + inputIdx int + input []rune // Textual source + + buffer []rune // Runes composing the current token tokens chan token line int col int @@ -114,16 +114,14 @@ func (s *sshLexer) lexRvalue() sshLexStateFn { } func (s *sshLexer) read() rune { - r, _, err := s.input.ReadRune() - if err != nil { - panic(err) - } + r := s.peek() if r == '\n' { s.endbufferLine++ s.endbufferCol = 1 } else { s.endbufferCol++ } + s.inputIdx++ return r } @@ -197,21 +195,22 @@ func (s *sshLexer) emitWithValue(t tokenType, value string) { } func (s *sshLexer) peek() rune { - r, _, err := s.input.ReadRune() - if err != nil { - panic(err) + if s.inputIdx >= len(s.input) { + return eof } - s.input.UnreadRune() + + r := s.input[s.inputIdx] return r } func (s *sshLexer) follow(next string) bool { + inputIdx := s.inputIdx for _, expectedRune := range next { - r, _, err := s.input.ReadRune() - defer s.input.UnreadRune() - if err != nil { - panic(err) + if inputIdx >= len(s.input) { + return false } + r := s.input[inputIdx] + inputIdx++ if expectedRune != r { return false } @@ -226,10 +225,10 @@ func (s *sshLexer) run() { close(s.tokens) } -func lexSSH(input io.Reader) chan token { - bufferedInput := buffruneio.NewReader(input) +func lexSSH(input []byte) chan token { + runes := bytes.Runes(input) l := &sshLexer{ - input: bufferedInput, + input: runes, tokens: make(chan token), line: 1, col: 1, diff --git a/parser.go b/parser.go index 10a2263..36c4205 100644 --- a/parser.go +++ b/parser.go @@ -169,6 +169,12 @@ func (p *sshParser) parseComment() sshParserStateFn { } func parseSSH(flow chan token, system bool, depth uint8) *Config { + // Ensure we consume tokens to completion even if parser exits early + defer func() { + for range flow { + } + }() + result := newConfig() result.position = Position{1, 1} parser := &sshParser{ diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..ff1ab2f --- /dev/null +++ b/parser_test.go @@ -0,0 +1,24 @@ +package ssh_config + +import ( + "errors" + "testing" +) + +type errReader struct { +} + +func (b *errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error occurred") +} + +func TestIOError(t *testing.T) { + buf := &errReader{} + _, err := Decode(buf) + if err == nil { + t.Fatal("expected non-nil err, got nil") + } + if err.Error() != "read error occurred" { + t.Errorf("expected read error msg, got %v", err) + } +}