all: rewrite the lexer to consume the entire input first
Previously we used the buffruneio package to buffer input. However, the error handling was not good, and we would often panic when parsing inputs. SSH config files are generally not large, on the order of kilobytes or megabytes, and it's fine to just read the entire thing into memory and then parse from there. This also simplifies the parser significantly and lets us remove a dependency and several defer calls. Add a test that panicked with the old version and then modify the code to ensure the test no longer panics. Thanks to Mark Nevill (@devnev) for the initial error report and failing test case. Fixes #10. Fixes #24.
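For reference, a minimal caller-side sketch of the Decode entry point this commit rewires to buffer the whole input before lexing; the import path and the sample host block are assumptions for illustration, not part of the change itself:

package main

import (
	"fmt"
	"log"
	"strings"

	ssh_config "github.com/kevinburke/ssh_config" // assumed import path for this package
)

func main() {
	// Decode now reads the entire reader into memory (ioutil.ReadAll) before lexing,
	// so any io.Reader, including a small in-memory config like this, is parsed in one pass.
	cfg, err := ssh_config.Decode(strings.NewReader("Host example\n  Port 2222\n"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg != nil) // true once the config parses cleanly
}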
.gitignore (vendored, 1 change)
@@ -1 +0,0 @@
-/bazel-*
config.go (18 changes)
@@ -34,6 +34,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"os"
 	osuser "os/user"
 	"path/filepath"
@@ -195,26 +196,29 @@ func parseFile(filename string) (*Config, error) {
 }
 
 func parseWithDepth(filename string, depth uint8) (*Config, error) {
-	f, err := os.Open(filename)
+	b, err := ioutil.ReadFile(filename)
 	if err != nil {
 		return nil, err
 	}
-	defer f.Close()
-	return decode(f, isSystem(filename), depth)
+	return decodeBytes(b, isSystem(filename), depth)
 }
 
 func isSystem(filename string) bool {
-	// TODO i'm not sure this is the best way to detect a system repo
+	// TODO: not sure this is the best way to detect a system repo
 	return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh")
 }
 
 // Decode reads r into a Config, or returns an error if r could not be parsed as
 // an SSH config file.
 func Decode(r io.Reader) (*Config, error) {
-	return decode(r, false, 0)
+	b, err := ioutil.ReadAll(r)
+	if err != nil {
+		return nil, err
+	}
+	return decodeBytes(b, false, 0)
 }
 
-func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
+func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) {
 	defer func() {
 		if r := recover(); r != nil {
 			if _, ok := r.(runtime.Error); ok {
@@ -228,7 +232,7 @@ func decode(r io.Reader, system bool, depth uint8) (c *Config, err error) {
 		}
 	}()
 
-	c = parseSSH(lexSSH(r), system, depth)
+	c = parseSSH(lexSSH(b), system, depth)
 	return c, err
 }
 
lexer.go (37 changes)
@@ -1,16 +1,16 @@
 package ssh_config
 
 import (
-	"io"
+	"bytes"
 
-	buffruneio "github.com/pelletier/go-buffruneio"
 )
 
 // Define state functions
type sshLexStateFn func() sshLexStateFn
 
 type sshLexer struct {
-	input *buffruneio.Reader // Textual source
+	inputIdx int
+	input    []rune // Textual source
 
 	buffer []rune // Runes composing the current token
 	tokens chan token
 	line   int
@@ -114,16 +114,14 @@ func (s *sshLexer) lexRvalue() sshLexStateFn {
 }
 
 func (s *sshLexer) read() rune {
-	r, _, err := s.input.ReadRune()
-	if err != nil {
-		panic(err)
-	}
+	r := s.peek()
 	if r == '\n' {
 		s.endbufferLine++
 		s.endbufferCol = 1
 	} else {
 		s.endbufferCol++
 	}
+	s.inputIdx++
 	return r
 }
 
@@ -197,21 +195,22 @@ func (s *sshLexer) emitWithValue(t tokenType, value string) {
 }
 
 func (s *sshLexer) peek() rune {
-	r, _, err := s.input.ReadRune()
-	if err != nil {
-		panic(err)
+	if s.inputIdx >= len(s.input) {
+		return eof
 	}
-	s.input.UnreadRune()
+
+	r := s.input[s.inputIdx]
 	return r
 }
 
 func (s *sshLexer) follow(next string) bool {
+	inputIdx := s.inputIdx
 	for _, expectedRune := range next {
-		r, _, err := s.input.ReadRune()
-		defer s.input.UnreadRune()
-		if err != nil {
-			panic(err)
+		if inputIdx >= len(s.input) {
+			return false
 		}
+		r := s.input[inputIdx]
+		inputIdx++
 		if expectedRune != r {
 			return false
 		}
@@ -226,10 +225,10 @@ func (s *sshLexer) run() {
 	close(s.tokens)
 }
 
-func lexSSH(input io.Reader) chan token {
-	bufferedInput := buffruneio.NewReader(input)
+func lexSSH(input []byte) chan token {
+	runes := bytes.Runes(input)
 	l := &sshLexer{
-		input:  bufferedInput,
+		input:  runes,
 		tokens: make(chan token),
 		line:   1,
 		col:    1,
parser.go
@@ -169,6 +169,12 @@ func (p *sshParser) parseComment() sshParserStateFn {
 }
 
 func parseSSH(flow chan token, system bool, depth uint8) *Config {
+	// Ensure we consume tokens to completion even if parser exits early
+	defer func() {
+		for range flow {
+		}
+	}()
+
 	result := newConfig()
 	result.position = Position{1, 1}
 	parser := &sshParser{
parser_test.go (new file, 24 additions)
@@ -0,0 +1,24 @@
+package ssh_config
+
+import (
+	"errors"
+	"testing"
+)
+
+type errReader struct {
+}
+
+func (b *errReader) Read(p []byte) (n int, err error) {
+	return 0, errors.New("read error occurred")
+}
+
+func TestIOError(t *testing.T) {
+	buf := &errReader{}
+	_, err := Decode(buf)
+	if err == nil {
+		t.Fatal("expected non-nil err, got nil")
+	}
+	if err.Error() != "read error occurred" {
+		t.Errorf("expected read error msg, got %v", err)
+	}
+}