Source File
sql.go
Belonging Package
github.com/jackc/pgx/v4/stdlib
package stdlib
import (
)
var databaseSQLResultFormats pgx.QueryResultFormatsByOID
var pgxDriver *Driver
type ctxKey int
var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func () {
pgxDriver = &Driver{
configs: make(map[string]*pgx.ConnConfig),
}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
pgtype.BoolOID: 1,
pgtype.ByteaOID: 1,
pgtype.CIDOID: 1,
pgtype.DateOID: 1,
pgtype.Float4OID: 1,
pgtype.Float8OID: 1,
pgtype.Int2OID: 1,
pgtype.Int4OID: 1,
pgtype.Int8OID: 1,
pgtype.OIDOID: 1,
pgtype.TimestampOID: 1,
pgtype.TimestamptzOID: 1,
pgtype.XIDOID: 1,
}
}
var (
fakeTxMutex sync.Mutex
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type OptionOpenDB func(*connector)
func ( func(context.Context, *pgx.Conn) error) OptionOpenDB {
return func( *connector) {
.AfterConnect =
}
}
func ( pgx.ConnConfig, ...OptionOpenDB) *sql.DB {
:= connector{
ConnConfig: ,
AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default
driver: pgxDriver,
}
for , := range {
(&)
}
return sql.OpenDB()
}
type connector struct {
pgx.ConnConfig
AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection
driver *Driver
}
func ( connector) ( context.Context) (driver.Conn, error) {
var (
error
*pgx.Conn
)
if , = pgx.ConnectConfig(, &.ConnConfig); != nil {
return nil,
}
if = .AfterConnect(, ); != nil {
return nil,
}
return &Conn{conn: , driver: .driver, connConfig: .ConnConfig}, nil
}
func () driver.Driver {
return pgxDriver
}
type Driver struct {
configMutex sync.Mutex
configs map[string]*pgx.ConnConfig
}
func ( *Driver) ( string) (driver.Conn, error) {
, := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
defer ()
, := .OpenConnector()
if != nil {
return nil,
}
return .Connect()
}
func ( *Driver) ( string) (driver.Connector, error) {
return &driverConnector{driver: , name: }, nil
}
func ( *Driver) ( *pgx.ConnConfig) string {
.configMutex.Lock()
:= fmt.Sprintf("registeredConnConfig%d", len(.configs))
.configs[] =
.configMutex.Unlock()
return
}
func ( *Driver) ( string) {
.configMutex.Lock()
delete(.configs, )
.configMutex.Unlock()
}
type driverConnector struct {
driver *Driver
name string
}
func ( *driverConnector) ( context.Context) (driver.Conn, error) {
var *pgx.ConnConfig
.driver.configMutex.Lock()
= .driver.configs[.name]
.driver.configMutex.Unlock()
if == nil {
var error
, = pgx.ParseConfig(.name)
if != nil {
return nil,
}
}
, := pgx.ConnectConfig(, )
if != nil {
return nil,
}
:= &Conn{conn: , driver: .driver, connConfig: *}
return , nil
}
func ( *driverConnector) () driver.Driver {
return .driver
}
func ( *pgx.ConnConfig) string {
return pgxDriver.registerConnConfig()
}
func ( string) {
pgxDriver.unregisterConnConfig()
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
connConfig pgx.ConnConfig
}
func ( *Conn) () *pgx.Conn {
return .conn
}
func ( *Conn) ( string) (driver.Stmt, error) {
return .PrepareContext(context.Background(), )
}
func ( *Conn) ( context.Context, string) (driver.Stmt, error) {
if .conn.IsClosed() {
return nil, driver.ErrBadConn
}
:= fmt.Sprintf("pgx_%d", .psCount)
.psCount++
, := .conn.Prepare(, , )
if != nil {
return nil,
}
return &Stmt{sd: , conn: }, nil
}
func ( *Conn) () error {
, := context.WithTimeout(context.Background(), time.Second*5)
defer ()
return .conn.Close()
}
func ( *Conn) () (driver.Tx, error) {
return .BeginTx(context.Background(), driver.TxOptions{})
}
func ( *Conn) ( context.Context, driver.TxOptions) (driver.Tx, error) {
if .conn.IsClosed() {
return nil, driver.ErrBadConn
}
if , := .Value(ctxKeyFakeTx).(**pgx.Conn); {
* = .conn
return fakeTx{}, nil
}
var pgx.TxOptions
switch sql.IsolationLevel(.Isolation) {
case sql.LevelDefault:
case sql.LevelReadUncommitted:
.IsoLevel = pgx.ReadUncommitted
case sql.LevelReadCommitted:
.IsoLevel = pgx.ReadCommitted
case sql.LevelRepeatableRead, sql.LevelSnapshot:
.IsoLevel = pgx.RepeatableRead
case sql.LevelSerializable:
.IsoLevel = pgx.Serializable
default:
return nil, fmt.Errorf("unsupported isolation: %v", .Isolation)
}
if .ReadOnly {
.AccessMode = pgx.ReadOnly
}
, := .conn.BeginTx(, )
if != nil {
return nil,
}
return wrapTx{ctx: , tx: }, nil
}
func ( *Conn) ( context.Context, string, []driver.NamedValue) (driver.Result, error) {
if .conn.IsClosed() {
return nil, driver.ErrBadConn
}
:= namedValueToInterface()
if != nil {
if pgconn.SafeToRetry() {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(.RowsAffected()),
}
func ( *Conn) ( context.Context, string, []driver.NamedValue) (driver.Rows, error) {
if .conn.IsClosed() {
return nil, driver.ErrBadConn
}
:= []interface{}{databaseSQLResultFormats}
= append(, namedValueToInterface()...)
, := .conn.Query(, , ...)
if != nil {
if pgconn.SafeToRetry() {
return nil, driver.ErrBadConn
}
return nil,
}
.Close()
return driver.ErrBadConn
}
return nil
}
return nil
}
func ( *Conn) ( context.Context) error {
if .conn.IsClosed() {
return driver.ErrBadConn
}
return nil
}
type Stmt struct {
sd *pgconn.StatementDescription
conn *Conn
}
func ( *Stmt) () error {
, := context.WithTimeout(context.Background(), time.Second*5)
defer ()
return .conn.conn.Deallocate(, .sd.Name)
}
func ( *Stmt) () int {
return len(.sd.ParamOIDs)
}
func ( *Stmt) ( []driver.Value) (driver.Result, error) {
return nil, errors.New("Stmt.Exec deprecated and not implemented")
}
func ( *Stmt) ( context.Context, []driver.NamedValue) (driver.Result, error) {
return .conn.ExecContext(, .sd.Name, )
}
func ( *Stmt) ( []driver.Value) (driver.Rows, error) {
return nil, errors.New("Stmt.Query deprecated and not implemented")
}
func ( *Stmt) ( context.Context, []driver.NamedValue) (driver.Rows, error) {
return .conn.QueryContext(, .sd.Name, )
}
type rowValueFunc func(src []byte) (driver.Value, error)
type Rows struct {
conn *Conn
rows pgx.Rows
valueFuncs []rowValueFunc
skipNext bool
skipNextMore bool
columnNames []string
}
func ( *Rows) () []string {
if .columnNames == nil {
:= .rows.FieldDescriptions()
.columnNames = make([]string, len())
for , := range {
.columnNames[] = string(.Name)
}
}
return .columnNames
}
func ( *Rows) ( int) string {
if , := .conn.conn.ConnInfo().DataTypeForOID(.rows.FieldDescriptions()[].DataTypeOID); {
return strings.ToUpper(.Name)
}
return strconv.FormatInt(int64(.rows.FieldDescriptions()[].DataTypeOID), 10)
}
const varHeaderSize = 4
func ( *Rows) ( int) (int64, bool) {
:= .rows.FieldDescriptions()[]
switch .DataTypeOID {
case pgtype.TextOID, pgtype.ByteaOID:
return math.MaxInt64, true
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
return int64(.TypeModifier - varHeaderSize), true
default:
return 0, false
}
}
func ( *Rows) ( int) (, int64, bool) {
:= .rows.FieldDescriptions()[]
switch .DataTypeOID {
case pgtype.NumericOID:
:= .TypeModifier - varHeaderSize
= int64(( >> 16) & 0xffff)
= int64( & 0xffff)
return , , true
default:
return 0, 0, false
}
}
func ( *Rows) ( int) reflect.Type {
:= .rows.FieldDescriptions()[]
switch .DataTypeOID {
case pgtype.Float8OID:
return reflect.TypeOf(float64(0))
case pgtype.Float4OID:
return reflect.TypeOf(float32(0))
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
case pgtype.BoolOID:
return reflect.TypeOf(false)
case pgtype.NumericOID:
return reflect.TypeOf(float64(0))
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
return reflect.TypeOf(time.Time{})
case pgtype.ByteaOID:
return reflect.TypeOf([]byte(nil))
default:
return reflect.TypeOf("")
}
}
func ( *Rows) () error {
.rows.Close()
return .rows.Err()
}
func ( *Rows) ( []driver.Value) error {
:= .conn.conn.ConnInfo()
:= .rows.FieldDescriptions()
if .valueFuncs == nil {
.valueFuncs = make([]rowValueFunc, len())
for , := range {
:= .DataTypeOID
:= .Format
switch .DataTypeOID {
case pgtype.BoolOID:
var bool
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return ,
}
case pgtype.ByteaOID:
var []byte
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return ,
}
case pgtype.CIDOID:
var pgtype.CID
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.DateOID:
var pgtype.Date
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.Float4OID:
var float32
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return float64(),
}
case pgtype.Float8OID:
var float64
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return ,
}
case pgtype.Int2OID:
var int16
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return int64(),
}
case pgtype.Int4OID:
var int32
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return int64(),
}
case pgtype.Int8OID:
var int64
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return ,
}
case pgtype.JSONOID:
var pgtype.JSON
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.JSONBOID:
var pgtype.JSONB
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.OIDOID:
var pgtype.OIDValue
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.TimestampOID:
var pgtype.Timestamp
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.TimestamptzOID:
var pgtype.Timestamptz
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
case pgtype.XIDOID:
var pgtype.XID
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
if != nil {
return nil,
}
return .Value()
}
default:
var string
:= .PlanScan(, , &)
.valueFuncs[] = func( []byte) (driver.Value, error) {
:= .Scan(, , , , &)
return ,
}
}
}
}
var bool
if .skipNext {
= .skipNextMore
.skipNext = false
} else {
= .rows.Next()
}
if ! {
if .rows.Err() == nil {
return io.EOF
} else {
return .rows.Err()
}
}
for , := range .rows.RawValues() {
if != nil {
var error
[], = .valueFuncs[]()
if != nil {
return fmt.Errorf("convert field %d failed: %v", , )
}
} else {
[] = nil
}
}
return nil
}
func ( []driver.Value) []interface{} {
:= make([]interface{}, 0, len())
for , := range {
if != nil {
= append(, .(interface{}))
} else {
= append(, nil)
}
}
return
}
func ( []driver.NamedValue) []interface{} {
:= make([]interface{}, 0, len())
for , := range {
if .Value != nil {
= append(, .Value.(interface{}))
} else {
= append(, nil)
}
}
return
}
type wrapTx struct {
ctx context.Context
tx pgx.Tx
}
func ( wrapTx) () error { return .tx.Commit(.ctx) }
func ( wrapTx) () error { return .tx.Rollback(.ctx) }
type fakeTx struct{}
func (fakeTx) () error { return nil }
func (fakeTx) () error { return nil }
func ( *sql.DB, *pgx.Conn) error {
var *sql.Tx
var bool
if .PgConn().IsBusy() || .PgConn().TxStatus() != 'I' {
, := context.WithTimeout(context.Background(), time.Second)
defer ()
.Close()
}
fakeTxMutex.Lock()
, = fakeTxConns[]
if {
delete(fakeTxConns, )
fakeTxMutex.Unlock()
} else {
fakeTxMutex.Unlock()
return fmt.Errorf("can't release conn that is not acquired")
}
return .Rollback()
![]() |
The pages are generated with Golds v0.3.2-preview. (GOOS=darwin GOARCH=amd64) Golds is a Go 101 project developed by Tapir Liu. PR and bug reports are welcome and can be submitted to the issue list. Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds. |