go自动化生成数据库curd代码(三):ANTLR解析SQL
在上一节我们了解了go的抽象语法树AST,并利用go提供的AST包拿到了用户定义的sql。接下来就是如何解析sql,将sql语句中的表名,列字段的名称,类型等关键信息提取出来。 这就需要我们的语法分析了,在本项目中我们决定采用ANTLR来完成此任务,他是一个强大的工具,下文我将为大家介绍是如何实现的。
ANTLR
简介
再讲ANTLR之前,还是想先提一下yacc。yacc是比较出名的语法分析器,不过年代久远,诞生于上个世纪70年代,yacc需要与lex一起才能实现完整的语法树构建。 lex是词法分析器,用于分割语句中的词块,也就是token。go官方也提供了goyacc给我们使用,网上也有关于yacc解析sql的源码。 不过我们还是选择了使用更多的ANTLR,ANTLR目前仍在维护,实现起来比较简单,开发快,还支持所有主流语言,还提供了可视化的语法树,debug特别方便。
安装
安装ANTLR有两种方式,最简单的是用pip3安装。因为我本机有python3,所以很方便。
$ pip install antlr4-tools
执行命令
$ antlr4
Downloading antlr4-4.11.1-complete.jar
ANTLR tool needs Java to run; install Java JRE 11 yes/no (default yes)? y
Installed Java in /Users/parrt/.jre/jdk-11.0.15+10-jre; remove that dir to uninstall
ANTLR Parser Generator Version 4.11.1
-o ___ specify output directory where all output is generated
-lib ___ specify location of grammars, tokens files
...
如果上面的命令都没问题,就是安装成功了,我们可以尝试下,比如实现一个计算器。 先创建Expr.g4,文件名必须与grammar相对应
grammar Expr;
prog: expr EOF ;
expr: expr ('*'|'/') expr
| expr ('+'|'-') expr
| INT
| '(' expr ')'
;
NEWLINE : [\r\n]+ -> skip;
INT : [0-9]+ ;
并使用强大的gui功能(语法树)
antlr4-parse Expr.g4 prog -gui
解析SQL
在编写规则的时候,本来是花了几天时间去实现,完成了表名以及id的定义,不过最后还是发现单单一个建表语句就有很多的规则。 如果单靠自己实现,可能会覆盖不全,而且我平时上班,可能需要花一个月的时间,写这个对我来说帮助也不是很大。 所以,我参照了ANTLR官方mysql的语法(ANTLR官方提供了大量的例子,有兴趣的可以去看看),稍微改造了下,只留下了建表的语法,其余的全部被我删除。 不过,lexer那里还是全部保留下来,虽然有很多token没有使用,考虑涉及到关键字的匹配分词,我都没删。 官方提供的语法虽然很牛逼,不过还是有好几个bug(有些规则为了复用,导致一些根本不会出现在建表规则的也匹配到了),不过这倒不影响,我们的功能要求是能解析,你只要能把正确的解析出来就行。 但是这里也不是说直接拷贝过来就完事,还是要考虑几个问题,解析是不支持多条语句的,如果多个表定义多个变量,分多次解析就行,表名也要支持db.tbl这种情况,mysql字段类型go中类型的转化问题,这些问题我都交给了运行时去解决。
运行时解析
先定义我们解析的结果
type ColumnDecl struct {
Decl string // sql字段定义,用于debug
Name string // 字段名称
Comment string // 字段描述
SqlType string // mysql中的类型
GoType GoType // go中对应的类型
IsNotNull bool // 是否可以为空(sqlx有Null类型)
}
// 索引(用于生成curd代码的查询条件)
type ColumnIndex struct {
Decl string
Columns []ColumnDecl
}
type TableAttr struct {
TableName string // 表名
Columns []ColumnDecl // 字段
PrimaryKey ColumnIndex // 主键
UniqueKeys []ColumnIndex // 唯一键
}
GoType的定义
type GoType string
const (
Invalid = "invalid"
Bool = "bool"
Int8 = "int8"
Int16 = "int16"
Int32 = "int32"
Int64 = "int64"
Uint8 = "uint8"
Uint16 = "uint16"
Uint32 = "uint32"
Uint64 = "uint64"
Float32 = "float32"
Float64 = "float64"
String = "string"
Time = "time.Time"
SliceByte = "[]byte"
SliceUint8 = "[]uint8"
)
定义好解析结果后,我们先用ANTLR生成代码
antlr4 -Dlanguage=Go *.g4
生成之后,我们实现自己的listener
type StmtListener struct {
*BaseStmtParserListener
column ColumnDecl
TableAttr
}
func NewStmtListener() *StmtListener {
return new(StmtListener)
}
代码比较长,这里以提取表名为例子
func (l *StmtListener) EnterTableName(ctx *TableNameContext) {
var tableName string
switch ctx.GetStop().GetTokenType() {
// 需要去掉引号
case StmtParserREVERSE_QUOTE_ID, StmtParserCHARSET_REVERSE_QOUTE_STRING, StmtParserSTRING_LITERAL:
name := ctx.GetStop().GetText()
if len(name) <= 2 {
return
}
tableName = name[1 : len(name)-1]
// db.tbl的形式
case StmtParserDOT_ID:
name := ctx.GetStop().GetText()
if len(name) <= 1 {
return
}
tableName = name[1:]
default:
tableName = ctx.GetText()
}
l.TableName = tableName
}
除了解析之外,我们还需要对错误进行处理,不然错误发生我们都还不知道,无法判断SQL是否正确
type ErrorListener struct {
*antlr.DefaultErrorListener
errors []error
}
func NewErrorListener() *ErrorListener {
return new(ErrorListener)
}
func (l *ErrorListener) HasError() bool {
return len(l.errors) > 0
}
func (l *ErrorListener) Errors() []error {
return l.errors
}
func (l *ErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
p := recognizer.(antlr.Parser)
stack := p.GetRuleInvocationStack(p.GetParserRuleContext())
err := fmt.Errorf("rule: %v line %d: %d at %v : %s", stack[0], line, column, offendingSymbol, msg)
l.errors = append(l.errors, err)
}
随后便将上面两个集成在一起使用
import (
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
parser "github.com/lemon-1997/sqlboy/antlr"
)
func parseStmt(ddl string) (parser.TableAttr, []error) {
input := antlr.NewInputStream(ddl)
lexer := parser.NewStmtLexer(input)
stream := antlr.NewCommonTokenStream(lexer, 0)
p := parser.NewStmtParser(stream)
el := parser.NewErrorListener()
p.RemoveErrorListeners()
p.AddErrorListener(el)
p.BuildParseTrees = true
tree := p.Prog()
if el.HasError() {
return parser.TableAttr{}, el.Errors()
}
l := parser.NewStmtListener()
antlr.ParseTreeWalkerDefault.Walk(l, tree)
return l.TableAttr, nil
}
在实现代码过程中,还发现了ANTLR go runtime包的一个错误,并提了个pr https://github.com/antlr/antlr4/pull/3999
小结
好了,到这里我们已经能够正确把SQL解析,并提取出我们想要的表字段等信息,有了这些信息后,我们就可以根据表的结构,去生成相应的代码了。
下一节我将向大家介绍如果用模板渲染出代码,有兴趣的可以关注一下。
项目源码:https://github.com/lemon-1997/sqlboy