Skip to content

Commit a50f403

Browse files
committed
core: public release
1 parent 48fbd01 commit a50f403

File tree

9 files changed

+697
-0
lines changed

9 files changed

+697
-0
lines changed

go.mod

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module github.com/phenpessoa/sql2go
2+
3+
go 1.21.4

internal/parser/evaluator.go

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package parser
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"reflect"
8+
"strings"
9+
)
10+
11+
func Parse[T any](dst *T, r io.Reader) error {
12+
v := reflect.ValueOf(dst).Elem()
13+
14+
data, _ := io.ReadAll(r)
15+
input := string(data)
16+
l := newLexer(input)
17+
p := newParser(l)
18+
tree := p.parse()
19+
20+
i := 0
21+
for i < len(tree.nodes) {
22+
n := tree.nodes[i]
23+
switch t := n.(type) {
24+
case nodeName:
25+
if !t.valid {
26+
return errors.New("sql2go: found an empty name")
27+
}
28+
29+
field := v.FieldByName(t.val)
30+
if !field.IsValid() || !field.CanSet() || !field.CanInterface() {
31+
return fmt.Errorf(
32+
"sql2go: field not found or invalid in dst struct: %s",
33+
t.val,
34+
)
35+
}
36+
37+
if _, ok := field.Interface().(string); !ok {
38+
return fmt.Errorf(
39+
"sql2go: field %s is not of type string", t.val,
40+
)
41+
}
42+
43+
var (
44+
query strings.Builder
45+
lastByte byte
46+
)
47+
i++
48+
for _, nn := range tree.nodes[i:] {
49+
switch t := nn.(type) {
50+
case nodeEnfOfQuery:
51+
query.Grow(1)
52+
query.WriteByte(';')
53+
lastByte = ';'
54+
case nodeName:
55+
goto out
56+
case nodeQuery:
57+
val := strings.TrimSpace(t.val)
58+
if lastByte == '\'' || lastByte == '"' || lastByte == '`' {
59+
query.Grow(len(val) + 1)
60+
query.WriteRune(' ')
61+
} else {
62+
query.Grow(len(val))
63+
}
64+
query.WriteString(val)
65+
lastByte = val[len(val)-1]
66+
case nodeStringLiteral:
67+
if lastByte != '\n' && lastByte != ' ' {
68+
query.Grow(len(t.val) + 3)
69+
query.WriteByte(' ')
70+
} else {
71+
query.Grow(len(t.val) + 2)
72+
}
73+
query.WriteByte('\'')
74+
query.WriteString(t.val)
75+
query.WriteByte('\'')
76+
lastByte = '\''
77+
case nodeIdentifier:
78+
if lastByte != '\n' && lastByte != ' ' {
79+
query.Grow(len(t.val) + 3)
80+
query.WriteByte(' ')
81+
} else {
82+
query.Grow(len(t.val) + 2)
83+
}
84+
query.WriteString(t.tok.literal)
85+
query.WriteString(t.val)
86+
query.WriteString(t.tok.literal)
87+
lastByte = t.tok.literal[0]
88+
case nodeNewLine:
89+
if lastByte != '\n' {
90+
query.Grow(1)
91+
query.WriteByte('\n')
92+
}
93+
lastByte = '\n'
94+
}
95+
96+
i++
97+
}
98+
99+
out:
100+
field.Set(reflect.ValueOf(strings.TrimSpace(query.String())))
101+
default:
102+
i++
103+
}
104+
}
105+
106+
return nil
107+
}

internal/parser/evaluator_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package parser
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"github.com/phenpessoa/sql2go/internal/testdata"
8+
)
9+
10+
func TestParser(t *testing.T) {
11+
type queries struct {
12+
Foo string
13+
Bar string
14+
Baz string
15+
Qux string
16+
Quux string
17+
Corge string
18+
Grault string
19+
HardToLex string
20+
Empty string
21+
Garply string
22+
Waldo string
23+
Fred string
24+
Whatif string
25+
WhatAboutThis string
26+
WhatAboutThis2 string
27+
Plugh string
28+
Xyzzy string
29+
Thud string
30+
}
31+
32+
want := queries{
33+
Foo: "SELECT * FROM foo;",
34+
Bar: "SELECT * FROM bar\nWHERE id = 123;",
35+
Baz: "SELECT\n*\nFROM\nbaz\nWHERE\nbaz = 123 AND\nbaz = baz;",
36+
Qux: "SELECT * FROM qux;",
37+
Quux: "SELECT * FROM quux\nWHERE quux = 123;",
38+
Corge: "SELECT '--' FROM corge;",
39+
Grault: "SELECT '\n-- name: Grault\n' FROM grault;",
40+
HardToLex: "SELECT;",
41+
Empty: "",
42+
Garply: "SELECT 'garply-string-literal' FROM garply;",
43+
Waldo: "SELECT \"waldo_identifier_1\" FROM waldo;",
44+
Fred: "SELECT `fred_identifier_2` FROM fred;",
45+
Whatif: "SELECT * FROM whatif;",
46+
WhatAboutThis: "SELECT 'foo--hard--string--literal' FROM whatAboutThis; `foo\"_identifier_3'`",
47+
WhatAboutThis2: "SELECT\n`foo\"_identifier_4`",
48+
Plugh: "SELECT * FROM plugh",
49+
Xyzzy: "SELECT * FROM xyzzy",
50+
Thud: "SELECT * FROM thud;\nSELECT * FROM thud2;",
51+
}
52+
53+
f, err := testdata.TestFS.Open("files/initial.sql")
54+
if err != nil {
55+
t.Errorf("failed to open inital.sql: %s", err)
56+
t.FailNow()
57+
return
58+
}
59+
60+
var got queries
61+
if err := Parse(&got, f); err != nil {
62+
t.Errorf("failed to parse initial.sql: %s", err)
63+
t.FailNow()
64+
return
65+
}
66+
67+
if got != want {
68+
t.Error("initial.sql not parsed properly\n")
69+
typ := reflect.TypeOf(got)
70+
for i := 0; i < typ.NumField(); i++ {
71+
f := typ.Field(i)
72+
fv1 := reflect.ValueOf(got).Field(i).Interface()
73+
fv2 := reflect.ValueOf(want).Field(i).Interface()
74+
75+
if fv1 == fv2 {
76+
continue
77+
}
78+
79+
t.Errorf(
80+
"field: %s\nwanted: %#+v\ngot: %#+v\n",
81+
f.Name, fv2, fv1,
82+
)
83+
}
84+
85+
t.FailNow()
86+
return
87+
}
88+
}

internal/parser/lexer.go

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package parser
2+
3+
import (
4+
"strings"
5+
)
6+
7+
// lexer performs byte-oriented scanning over the normalized input.
type lexer struct {
	input   string // full source text, with "\r\n" normalized to "\n"
	pos     int    // index of the byte currently held in ch
	readPos int    // index of the next byte to read (pos + 1)
	ch      byte   // byte under examination; 0 is the EOF sentinel
}
13+
14+
func newLexer(input string) *lexer {
15+
l := &lexer{input: strings.ReplaceAll(input, "\r\n", "\n")}
16+
l.readChar()
17+
return l
18+
}
19+
20+
func (l *lexer) readChar() {
21+
if l.readPos >= len(l.input) {
22+
l.ch = 0
23+
} else {
24+
l.ch = l.input[l.readPos]
25+
}
26+
l.pos = l.readPos
27+
l.readPos++
28+
}
29+
30+
// moveBack rewinds the lexer by one position so that the byte which
// terminated a raw-input scan is re-read on the caller's next readChar.
//
// NOTE(review): after the rewind ch is reloaded from input[readPos],
// which is the old pos — so ch itself is unchanged here; only
// pos/readPos shift. Outside the valid window (start or end of input)
// the lexer is pinned to the EOF sentinel instead. Confirm this
// asymmetry is intentional before restructuring.
func (l *lexer) moveBack() {
	if l.pos > 0 && l.pos < len(l.input) {
		l.readPos = l.pos
		l.pos--
		l.ch = l.input[l.readPos]
	} else {
		l.ch = 0
	}
}
39+
40+
func (l *lexer) readLine() string {
41+
pos := l.pos
42+
for l.ch != '\n' {
43+
l.readChar()
44+
}
45+
return l.input[pos:l.pos]
46+
}
47+
48+
func (l *lexer) peekChar() byte {
49+
if l.readPos >= len(l.input) {
50+
return 0
51+
}
52+
return l.input[l.readPos]
53+
}
54+
55+
const (
	// nameBytes is the tail of a "-- name: " annotation; the leading
	// "--" has already been consumed by the time it is matched.
	nameBytes = " name: "
)
58+
59+
// isName detects if we are in a name token,
60+
// if true it will consume the bytes
61+
func (l *lexer) isName() bool {
62+
counter := 0
63+
for counter < len(nameBytes) &&
64+
l.input[l.readPos+counter] == nameBytes[counter] {
65+
counter++
66+
}
67+
if counter == len(nameBytes) {
68+
for i := 0; i < len(nameBytes); i++ {
69+
l.readChar()
70+
}
71+
return true
72+
}
73+
return false
74+
}
75+
76+
func (l *lexer) skipWhitespace() {
77+
for l.ch == ' ' || l.ch == '\t' || l.ch == '\r' {
78+
l.readChar()
79+
}
80+
}
81+
82+
func (l *lexer) readRawInput() string {
83+
pos := l.pos
84+
outer:
85+
for {
86+
switch l.ch {
87+
case '-':
88+
if l.peekChar() == '-' {
89+
break outer
90+
}
91+
case ';', '\'', '"', '`':
92+
break outer
93+
case '\n', 0:
94+
break outer
95+
}
96+
l.readChar()
97+
}
98+
data := l.input[pos:l.pos]
99+
l.moveBack()
100+
return data
101+
}
102+
103+
func (l *lexer) nextToken() token {
104+
var t token
105+
106+
l.skipWhitespace()
107+
108+
switch l.ch {
109+
case '-':
110+
if l.peekChar() == '-' {
111+
l.readChar()
112+
if l.isName() {
113+
t.literal = "-- name: "
114+
t.typ = tokenTypeName
115+
} else {
116+
t.literal = "--"
117+
t.typ = tokenTypeComment
118+
}
119+
} else {
120+
t.typ = tokenTypeUndefined
121+
}
122+
case '\n':
123+
t.typ = tokenTypeNewLine
124+
t.literal = "\n"
125+
case '"', '`':
126+
t.typ = tokenTypeIdentifier
127+
t.literal = string(l.ch)
128+
case ';':
129+
t.typ = tokenTypeSemicolon
130+
t.literal = ";"
131+
case 0:
132+
t.typ = tokenTypeEOF
133+
case '\'':
134+
t.typ = tokenTypeStringLiteral
135+
t.literal = "'"
136+
default:
137+
t.typ = tokenTypeRawInput
138+
t.literal = l.readRawInput()
139+
}
140+
141+
l.readChar()
142+
return t
143+
}

0 commit comments

Comments
 (0)