// Part is one element of a parsed Query. Sanitize accepts two dynamic types:
// a string, which is SQL text copied through unchanged, and an int, which is
// the number N of a $N placeholder (1-based) replaced by the N-th argument.
// Any other dynamic type makes Sanitize return an error.
type Part interface {}
// Query is SQL text that NewQuery has split into literal text and
// placeholder Parts, ready to be rendered with Sanitize.
type Query struct {
// Parts holds the query pieces in source order (see Part for the
// dynamic types Sanitize understands).
Parts []Part
}
// Sanitize renders q as a single SQL string in which every placeholder Part
// is replaced by a SQL literal built from the corresponding element of args.
//
// String Parts are copied verbatim. An int Part N is replaced by args[N-1],
// formatted by its dynamic type: nil as null; int64, float64 and bool via
// strconv; []byte via QuoteBytes; string via QuoteString; time.Time
// (truncated to microseconds) as a quoted timestamp literal. Each substituted
// value is surrounded by a single space so it cannot fuse with neighbouring
// SQL text. In particular, a negative number placed right after a '-' (as in
// "-$1") would otherwise produce "--" and turn the rest of the line into a
// comment.
//
// An error is returned if a placeholder number is less than 1 or greater than
// len(args), if an argument or Part has an unsupported type, or if any
// argument is never referenced. The unused-argument error reports the
// zero-based index into args.
func (q *Query) Sanitize(args ...interface{}) (string, error) {
	argUse := make([]bool, len(args))
	buf := &bytes.Buffer{}

	for _, part := range q.Parts {
		var str string
		switch part := part.(type) {
		case string:
			str = part
		case int:
			argIdx := part - 1
			// Placeholders are 1-based. Without this check a non-positive
			// number (e.g. "$0") would index args out of range and panic.
			if argIdx < 0 {
				return "", fmt.Errorf("first sql argument must be > 0")
			}
			if argIdx >= len(args) {
				return "", fmt.Errorf("insufficient arguments")
			}
			arg := args[argIdx]
			switch arg := arg.(type) {
			case nil:
				str = "null"
			case int64:
				str = strconv.FormatInt(arg, 10)
			case float64:
				// NOTE(review): strconv renders NaN and ±Inf as "NaN", "+Inf"
				// and "-Inf", which are not valid SQL numeric literals. TODO:
				// decide how these should be encoded.
				str = strconv.FormatFloat(arg, 'f', -1, 64)
			case bool:
				str = strconv.FormatBool(arg)
			case []byte:
				str = QuoteBytes(arg)
			case string:
				str = QuoteString(arg)
			case time.Time:
				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
			default:
				return "", fmt.Errorf("invalid arg type: %T", arg)
			}
			argUse[argIdx] = true

			// Prevent SQL injection via line comment creation: "-$1" with a
			// negative argument must not render as "--1".
			// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
			str = " " + str + " "
		default:
			return "", fmt.Errorf("invalid Part type: %T", part)
		}
		buf.WriteString(str)
	}

	for i, used := range argUse {
		if !used {
			return "", fmt.Errorf("unused argument: %d", i)
		}
	}
	return buf.String(), nil
}
// NewQuery splits sql into a Query by running the sqlLexer state machine,
// starting in rawState, until a state function returns nil. The lexer's
// accumulated parts become the Query's Parts. The returned error is currently
// always nil.
func NewQuery(sql string) (*Query, error) {
	lx := &sqlLexer{src: sql, stateFn: rawState}
	for {
		step := lx.stateFn
		if step == nil {
			break
		}
		lx.stateFn = step(lx)
	}
	return &Query{Parts: lx.parts}, nil
}
// QuoteString returns str wrapped in single quotes, with every embedded single
// quote doubled. All other bytes, including backslashes and bytes that are not
// valid UTF-8, are copied unchanged.
//
// NOTE(review): backslashes are not escaped, so this presumably assumes the
// server runs with standard_conforming_strings=on. Confirm with callers.
func QuoteString(str string) string {
	var b strings.Builder
	b.Grow(len(str) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(str); i++ {
		c := str[i]
		if c == '\'' {
			b.WriteByte('\'')
		}
		b.WriteByte(c)
	}
	b.WriteByte('\'')
	return b.String()
}
// QuoteBytes returns buf as a single-quoted literal of the form '\x…', where
// the body is the lowercase hexadecimal encoding of buf (two digits per byte).
// An empty or nil buf yields '\x'.
func QuoteBytes(buf []byte) string {
	const prefix = `'\x`
	out := make([]byte, len(prefix)+hex.EncodedLen(len(buf))+1)
	n := copy(out, prefix)
	n += hex.Encode(out[n:], buf)
	out[n] = '\''
	return string(out)
}
// sqlLexer holds the state of the state machine that NewQuery runs to split
// SQL text into Parts.
type sqlLexer struct {
// src is the complete SQL text being lexed.
src string
// start is the byte offset of the first byte not yet emitted as a Part.
start int
// pos is the byte offset of the next byte to examine.
pos int
// nested is not used by the states visible here. Presumably it tracks
// /* */ comment nesting depth in multilineCommentState (TODO confirm).
nested int
// stateFn is the next state to run; nil stops the lexer.
stateFn stateFn
// parts accumulates the emitted Parts in source order.
parts []Part
}
// stateFn is one lexer state. It consumes input from l.pos, possibly appends
// to l.parts, and returns the state to run next, or nil to stop lexing.
type stateFn func (*sqlLexer ) stateFn
// rawState lexes SQL text outside of string literals, quoted identifiers and
// comments. It transfers control to:
//   - escapeStringState after e' or E',
//   - singleQuoteState after ',
//   - doubleQuoteState after ",
//   - oneLineCommentState after --,
//   - multilineCommentState after /*,
//   - placeholderState after '$' followed by an ASCII digit. Before handing
//     off, it emits the text scanned so far (excluding the '$') as a string
//     Part.
//
// At end of input it emits any remaining text as a string Part and returns
// nil, which stops the lexer. Bytes that are not valid UTF-8, and literal
// U+FFFD characters, are passed through as ordinary text.
func rawState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			// End of input. Test the width rather than r == utf8.RuneError.
			// RuneError is also returned for an invalid byte (width 1) and for
			// a literal U+FFFD (width 3). Stopping there would silently drop
			// the rest of the query.
			if l.pos-l.start > 0 {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		switch r {
		case 'e', 'E':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '\'' {
				l.pos += nextWidth
				return escapeStringState
			}
		case '\'':
			return singleQuoteState
		case '"':
			return doubleQuoteState
		case '$':
			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
			if '0' <= nextRune && nextRune <= '9' {
				// l.pos is already past the '$'. Emit only the text before it,
				// and emit nothing when that text is empty (e.g. when the query
				// starts with a placeholder).
				if textEnd := l.pos - width; textEnd > l.start {
					l.parts = append(l.parts, l.src[l.start:textEnd])
				}
				l.start = l.pos
				return placeholderState
			}
		case '-':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '-' {
				l.pos += nextWidth
				return oneLineCommentState
			}
		case '/':
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune == '*' {
				l.pos += nextWidth
				return multilineCommentState
			}
		}
	}
}
// singleQuoteState lexes the body of a '...' string literal. The opening quote
// has already been consumed by the calling state. A doubled quote ('') is an
// escaped quote and stays inside the literal. Any other quote ends the literal
// and returns to rawState. The literal is not emitted here; it stays part of
// the pending text that rawState later emits. At end of input, an unterminated
// literal is emitted as-is and lexing stops.
func singleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			// End of input. Width 0 is the only reliable signal: an invalid
			// byte or a literal U+FFFD also decodes as utf8.RuneError, and
			// stopping on those would drop the rest of the query.
			if l.pos-l.start > 0 {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		if r == '\'' {
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '\'' {
				return rawState
			}
			// '' is an escaped quote: consume it and stay inside the literal.
			l.pos += nextWidth
		}
	}
}
// doubleQuoteState lexes the body of a "..." section (a quoted identifier).
// The opening quote has already been consumed by the calling state. A doubled
// quote ("") is an escaped quote and stays inside the identifier. Any other
// quote ends it and returns to rawState. The text is not emitted here; it stays
// part of the pending text that rawState later emits. At end of input, an
// unterminated section is emitted as-is and lexing stops.
func doubleQuoteState(l *sqlLexer) stateFn {
	for {
		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
		if width == 0 {
			// End of input. Width 0 is the only reliable signal: an invalid
			// byte or a literal U+FFFD also decodes as utf8.RuneError, and
			// stopping on those would drop the rest of the query.
			if l.pos-l.start > 0 {
				l.parts = append(l.parts, l.src[l.start:l.pos])
				l.start = l.pos
			}
			return nil
		}
		l.pos += width

		if r == '"' {
			nextRune, nextWidth := utf8.DecodeRuneInString(l.src[l.pos:])
			if nextRune != '"' {
				return rawState
			}
			// "" is an escaped quote: consume it and stay inside the identifier.
			l.pos += nextWidth
		}
	}
}