Skip to content

Commit fe7e7b1

Browse files
committed
Fix type cast for ::regproc
1 parent 247a8b0 commit fe7e7b1

File tree

6 files changed

+186
-145
lines changed

6 files changed

+186
-145
lines changed

src/parser_type.go

Lines changed: 0 additions & 96 deletions
This file was deleted.

src/parser_type_cast.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package main
2+
3+
import (
4+
"strings"
5+
6+
pgQuery "github.com/pganalyze/pg_query_go/v5"
7+
)
8+
9+
type ParserTypeCast struct {
10+
utils *ParserUtils
11+
config *Config
12+
}
13+
14+
func NewParserTypeCast(config *Config) *ParserTypeCast {
15+
return &ParserTypeCast{utils: NewParserUtils(config), config: config}
16+
}
17+
18+
func (parser *ParserTypeCast) TypeCast(node *pgQuery.Node) *pgQuery.TypeCast {
19+
if node.GetTypeCast() == nil {
20+
return nil
21+
}
22+
23+
typeCast := node.GetTypeCast()
24+
if len(typeCast.TypeName.Names) == 0 {
25+
return nil
26+
}
27+
28+
return typeCast
29+
}
30+
31+
func (parser *ParserTypeCast) TypeName(typeCast *pgQuery.TypeCast) string {
32+
return typeCast.TypeName.Names[0].GetString_().Sval
33+
}
34+
35+
func (parser *ParserTypeCast) ArgStringValue(typeCast *pgQuery.TypeCast) string {
36+
return typeCast.Arg.GetAConst().GetSval().Sval
37+
}
38+
39+
func (parser *ParserTypeCast) MakeCaseTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
40+
if existingType := parser.inferNodeType(arg); existingType == typeName {
41+
return arg
42+
}
43+
return parser.utils.MakeTypeCastNode(arg, typeName)
44+
}
45+
46+
func (parser *ParserTypeCast) MakeListValueFromArray(node *pgQuery.Node) *pgQuery.Node {
47+
arrayStr := node.GetAConst().GetSval().Sval
48+
arrayStr = strings.Trim(arrayStr, "{}")
49+
elements := strings.Split(arrayStr, ",")
50+
51+
funcCall := &pgQuery.FuncCall{
52+
Funcname: []*pgQuery.Node{
53+
pgQuery.MakeStrNode("list_value"),
54+
},
55+
}
56+
57+
for _, elem := range elements {
58+
funcCall.Args = append(funcCall.Args,
59+
pgQuery.MakeAConstStrNode(elem, 0))
60+
}
61+
62+
return &pgQuery.Node{
63+
Node: &pgQuery.Node_FuncCall{
64+
FuncCall: funcCall,
65+
},
66+
}
67+
}
68+
69+
func (parser *ParserTypeCast) inferNodeType(node *pgQuery.Node) string {
70+
if typeCast := node.GetTypeCast(); typeCast != nil {
71+
return typeCast.TypeName.Names[0].GetString_().Sval
72+
}
73+
74+
if aConst := node.GetAConst(); aConst != nil {
75+
switch {
76+
case aConst.GetBoolval() != nil:
77+
return "boolean"
78+
case aConst.GetIval() != nil:
79+
return "int8"
80+
case aConst.GetSval() != nil:
81+
return "text"
82+
}
83+
}
84+
return ""
85+
}

src/parser_utils.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ func (utils *ParserUtils) MakeSubselectWithRowsNode(tableName string, tableDef T
7070
}
7171

7272
func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDef TableDefinition, alias string) *pgQuery.Node {
73-
parserType := NewParserType(utils.config)
7473
columnNodes := make([]*pgQuery.Node, len(tableDef.Columns))
7574
for i, col := range tableDef.Columns {
7675
columnNodes[i] = pgQuery.MakeStrNode(col.Name)
@@ -85,7 +84,7 @@ func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDe
8584
},
8685
},
8786
}
88-
typedNullNode := parserType.MakeTypeCastNode(nullNode, col.Type)
87+
typedNullNode := utils.MakeTypeCastNode(nullNode, col.Type)
8988
targetList[i] = pgQuery.MakeResTargetNodeWithVal(typedNullNode, 0)
9089
}
9190

@@ -154,8 +153,6 @@ func (utils *ParserUtils) MakeAConstBoolNode(val bool) *pgQuery.Node {
154153
}
155154

156155
func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery.Node {
157-
parserType := NewParserType(utils.config)
158-
159156
if val == "NULL" {
160157
return &pgQuery.Node{
161158
Node: &pgQuery.Node_AConst{
@@ -168,5 +165,21 @@ func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery
168165

169166
constNode := pgQuery.MakeAConstStrNode(val, 0)
170167

171-
return parserType.MakeTypeCastNode(constNode, pgType)
168+
return utils.MakeTypeCastNode(constNode, pgType)
169+
}
170+
171+
func (utils *ParserUtils) MakeTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
172+
return &pgQuery.Node{
173+
Node: &pgQuery.Node_TypeCast{
174+
TypeCast: &pgQuery.TypeCast{
175+
Arg: arg,
176+
TypeName: &pgQuery.TypeName{
177+
Names: []*pgQuery.Node{
178+
pgQuery.MakeStrNode(typeName),
179+
},
180+
Location: 0,
181+
},
182+
},
183+
},
184+
}
172185
}

src/query_handler_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,11 @@ func TestHandleQuery(t *testing.T) {
679679
"types": {Uint32ToString(pgtype.Int2OID)},
680680
"values": {"1"},
681681
},
682+
"SELECT 'pg_catalog.array_in'::regproc AS regproc": {
683+
"description": {"regproc"},
684+
"types": {Uint32ToString(pgtype.TextOID)},
685+
"values": {"array_in"},
686+
},
682687

683688
// SELECT * FROM function()
684689
"SELECT * FROM pg_catalog.pg_get_keywords() LIMIT 1": {

0 commit comments

Comments
 (0)