// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package database adds some useful functionality to a sql.DB.
// It is independent of the database driver and the DB schema.
package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgconn"
	"github.com/lib/pq"
	"golang.org/x/pkgsite/internal/derrors"
	"golang.org/x/pkgsite/internal/log"
)
// DB wraps a sql.DB. The methods it exports correspond closely to those of
// sql.DB. They enhance the original by requiring a context argument, and by
// logging the query and any resulting errors.
//
// A DB may represent a transaction. If so, its execution and query methods
// operate within the transaction.
type DB struct {
	db         *sql.DB       // underlying database handle
	instanceID string        // passed to logQuery with every logged query
	tx         *sql.Tx       // non-nil when this DB represents a transaction
	conn       *sql.Conn     // the Conn of the Tx, when tx != nil
	opts       sql.TxOptions // valid when tx != nil
	mu         sync.Mutex    // guards maxRetries
	maxRetries int           // max times a single transaction was retried
}
Open creates a new DB for the given connection string.
func (, ,  string) ( *DB,  error) {
	defer derrors.Wrap(&, "database.Open(%q, %q)",
		, redactPassword())

	,  := sql.Open(, )
	if  != nil {
		return nil, 
	}

	,  := context.WithTimeout(context.Background(), 30*time.Second)
	defer ()
	if  := .PingContext();  != nil {
		return nil, 
	}
	return New(, ), nil
}
New creates a new DB from a sql.DB.
func ( *sql.DB,  string) *DB {
	return &DB{db: , instanceID: }
}

func ( *DB) () error {
	return .db.Ping()
}

func ( *DB) () bool {
	return .tx != nil
}

func ( *DB) () bool {
	return .tx != nil && isRetryable(.opts.Isolation)
}

// passwordRegexp matches a "password=..." setting in a connection string.
// The value is either a single-quoted string, which may contain spaces and
// backslash escapes, or a run of non-space characters.
var passwordRegexp = regexp.MustCompile(`password=(?:'(?:[^'\\]|\\.)*'|\S+)`)

// redactPassword returns dbinfo with any password setting replaced by
// "password=REDACTED", so the connection string can be logged safely.
func redactPassword(dbinfo string) string {
	return passwordRegexp.ReplaceAllLiteralString(dbinfo, "password=REDACTED")
}
Close closes the database connection.
func ( *DB) () error {
	return .db.Close()
}
Exec executes a SQL statement and returns the number of rows it affected.
func ( *DB) ( context.Context,  string,  ...interface{}) ( int64,  error) {
	defer logQuery(, , , .instanceID, .IsRetryable())(&)
	,  := .execResult(, , ...)
	if  != nil {
		return 0, 
	}
	,  := .RowsAffected()
	if  != nil {
		return 0, fmt.Errorf("RowsAffected: %v", )
	}
	return , nil
}
execResult executes a SQL statement and returns a sql.Result.
func ( *DB) ( context.Context,  string,  ...interface{}) ( sql.Result,  error) {
	if .tx != nil {
		return .tx.ExecContext(, , ...)
	}
	return .db.ExecContext(, , ...)
}
Query runs the DB query.
func ( *DB) ( context.Context,  string,  ...interface{}) ( *sql.Rows,  error) {
	defer logQuery(, , , .instanceID, .IsRetryable())(&)
	if .tx != nil {
		return .tx.QueryContext(, , ...)
	}
	return .db.QueryContext(, , ...)
}
QueryRow runs the query and returns a single row.
func ( *DB) ( context.Context,  string,  ...interface{}) *sql.Row {
	defer logQuery(, , , .instanceID, .IsRetryable())(nil)
	 := time.Now()
	defer func() {
		if .Err() != nil {
			,  := .Deadline()
			 := fmt.Sprintf("args=%v; elapsed=%q, start=%q, deadline=%q", , time.Since(), , )
			log.Errorf(, "QueryRow context error: %v "+, .Err())
		}
	}()
	if .tx != nil {
		return .tx.QueryRowContext(, , ...)
	}
	return .db.QueryRowContext(, , ...)
}

func ( *DB) ( context.Context,  string) (*sql.Stmt, error) {
	defer logQuery(, "preparing "+, nil, .instanceID, .IsRetryable())
	if .tx != nil {
		return .tx.PrepareContext(, )
	}
	return .db.PrepareContext(, )
}
RunQuery executes query, then calls f on each row.
func ( *DB) ( context.Context,  string,  func(*sql.Rows) error,  ...interface{}) error {
	,  := .Query(, , ...)
	if  != nil {
		return 
	}
	return processRows(, )
}

// processRows calls f on each row of rows, returning the first error from f
// or, after iteration ends, any iteration error from rows.Err. rows is closed
// before processRows returns.
func processRows(rows *sql.Rows, f func(*sql.Rows) error) error {
	defer rows.Close()
	for rows.Next() {
		if err := f(rows); err != nil {
			return err
		}
	}
	return rows.Err()
}
Transact executes the given function in the context of a SQL transaction at the given isolation level, rolling back the transaction if the function panics or returns an error. The given function is called with a DB that is associated with a transaction. The DB should be used only inside the function; if it is used to access the database after the function returns, the calls will return errors. If the isolation level requires it, Transact will retry the transaction upon serialization failure, so txFunc may be called more than once.
func ( *DB) ( context.Context,  sql.IsolationLevel,  func(*DB) error) ( error) {
For the levels which require retry, see https://www.postgresql.org/docs/11/transaction-iso.html.
	 := &sql.TxOptions{Isolation: }
	if isRetryable() {
		return .transactWithRetry(, , )
	}
	return .transact(, , )
}

// isRetryable reports whether transactions at isolation level iso can fail
// with a serialization failure and should therefore be retried.
func isRetryable(iso sql.IsolationLevel) bool {
	return iso == sql.LevelRepeatableRead || iso == sql.LevelSerializable
}
// serializationFailureCode is the Postgres error code returned when a
// serializable transaction fails because it would violate serializability.
// See https://www.postgresql.org/docs/current/errcodes-appendix.html.
const serializationFailureCode = "40001"

func ( *DB) ( context.Context,  *sql.TxOptions,  func(*DB) error) ( error) {
Retry on serialization failure, up to some max. See https://www.postgresql.org/docs/11/transaction-iso.html.
	const  = 30
	for  := 0;  <= ; ++ {
		 = .transact(, , )
		if isSerializationFailure() {
			.mu.Lock()
			if  > .maxRetries {
				.maxRetries = 
			}
			.mu.Unlock()
			log.Debugf(, "serialization failure; retrying")
			continue
		}
		if  != nil {
			log.Debugf(, "transactWithRetry: error type %T: %[1]v", )
			if strings.Contains(.Error(), serializationFailureCode) {
				return fmt.Errorf("error text has %q but not recognized as serialization failure: type %T, err %v",
					serializationFailureCode, , )
			}
		}
		if  > 0 {
			log.Debugf(, "retried serializable transaction %d time(s)", )
		}
		return 
	}
	return fmt.Errorf("reached max number of tries due to serialization failure (%d)", )
}

The underlying error type depends on the driver. Try both pq and pgx types.
	var  *pq.Error
	if errors.As(, &) && .Code == serializationFailureCode {
		return true
	}
	var  *pgconn.PgError
	if errors.As(, &) && .Code == serializationFailureCode {
		return true
	}
	return false
}

func ( *DB) ( context.Context,  *sql.TxOptions,  func(*DB) error) ( error) {
	if .InTransaction() {
		return errors.New("a DB Transact function was called on a DB already in a transaction")
	}
	,  := .db.Conn()
	if  != nil {
		return 
	}
	defer .Close()

	,  := .BeginTx(, )
	if  != nil {
		return fmt.Errorf("conn.BeginTx(): %w", )
	}
	defer func() {
		if  := recover();  != nil {
			.Rollback()
			panic()
		} else if  != nil {
			.Rollback()
		} else {
			if  := .Commit();  != nil {
				 = fmt.Errorf("tx.Commit(): %w", )
			}
		}
	}()

	 := New(.db, .instanceID)
	.tx = 
	.conn = 
	.opts = *
	defer .logTransaction()(&)
	if  := ();  != nil {
		return fmt.Errorf("txFunc(tx): %w", )
	}
	return nil
}
MaxRetries returns the maximum number of times thata serializable transaction was retried.
func ( *DB) () int {
	.mu.Lock()
	defer .mu.Unlock()
	return .maxRetries
}

// OnConflictDoNothing can be passed as the conflictAction argument of
// BulkInsert or BulkInsertReturning; it is appended verbatim to the generated
// INSERT statement.
const OnConflictDoNothing = "ON CONFLICT DO NOTHING"
BulkInsert constructs and executes a multi-value insert statement. The query is constructed using the format: INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) If conflictAction is not empty, it is appended to the statement. The query is executed using a PREPARE statement with the provided values.
func ( *DB) ( context.Context,  string,  []string,  []interface{},  string) ( error) {
	defer derrors.Wrap(&, "DB.BulkInsert(ctx, %q, %v, [%d values], %q)",
		, , len(), )

	return .bulkInsert(, , , nil, , , nil)
}
BulkInsertReturning is like BulkInsert, but supports returning values from the INSERT statement. In addition to the arguments of BulkInsert, it takes a list of columns to return and a function to scan those columns. To get the returned values, provide a function that scans them as if they were the selected columns of a query. See TestBulkInsert for an example.
func ( *DB) ( context.Context,  string,  []string,  []interface{},  string,  []string,  func(*sql.Rows) error) ( error) {
	defer derrors.Wrap(&, "DB.BulkInsertReturning(ctx, %q, %v, [%d values], %q, %v, scanFunc)",
		, , len(), , )

	if  == nil ||  == nil {
		return errors.New("need returningColumns and scan function")
	}
	return .bulkInsert(, , , , , , )
}
BulkUpsert is like BulkInsert, but instead of a conflict action, a list of conflicting columns is provided. An "ON CONFLICT (conflict_columns) DO UPDATE" clause is added to the statement, with assignments "c=excluded.c" for every column c.
func ( *DB) ( context.Context,  string,  []string,  []interface{},  []string) error {
	 := buildUpsertConflictAction(, )
	return .BulkInsert(, , , , )
}
BulkUpsertReturning is like BulkInsertReturning, but performs an upsert like BulkUpsert.
func ( *DB) ( context.Context,  string,  []string,  []interface{}, ,  []string,  func(*sql.Rows) error) error {
	 := buildUpsertConflictAction(, )
	return .BulkInsertReturning(, , , , , , )
}

func ( *DB) ( context.Context,  string, ,  []string,  []interface{},  string,  func(*sql.Rows) error) ( error) {
	if  := len() % len();  != 0 {
		return fmt.Errorf("modulus of len(values) and len(columns) must be 0: got %d", )
	}
Postgres supports up to 65535 parameters, but stop well before that so we don't construct humongous queries.
	const  = 1000
	 := ( / len()) * len()
This is a pathological case (len(columns) > maxParameters), but we handle it cautiously.
		return fmt.Errorf("too many columns to insert: %d", len())
	}

	 := func( int) (*sql.Stmt, error) {
		return .Prepare(, buildInsertQuery(, , , , ))
	}

	var  *sql.Stmt
	for  := 0;  < len();  +=  {
		 :=  + 
		if  <= len() &&  == nil {
			,  = ()
			if  != nil {
				return 
			}
			defer .Close()
		} else if  > len() {
			 = len()
			,  = ( - )
			if  != nil {
				return 
			}
			defer .Close()
		}
		 := [:]
		var  error
		if  == nil {
			_,  = .ExecContext(, ...)
		} else {
			var  *sql.Rows
			,  = .QueryContext(, ...)
			if  != nil {
				return 
			}
			 = processRows(, )
		}
		if  != nil {
			return fmt.Errorf("running bulk insert query, values[%d:%d]): %w", , , )
		}
	}
	return nil
}
// buildInsertQuery builds a multi-value insert query, following the format:
//
//	INSERT INTO <table> (<columns>) VALUES (<placeholders-for-each-item-in-values>) <conflictAction>
//
// If returningColumns is not empty, it appends a RETURNING clause to the
// query. nvalues is the total number of placeholders. Callers must ensure
// nvalues % len(columns) == 0; an incomplete trailing row is omitted, and an
// empty columns list produces no value rows.
func buildInsertQuery(table string, columns, returningColumns []string, nvalues int, conflictAction string) string {
	var b strings.Builder
	fmt.Fprintf(&b, "INSERT INTO %s", table)
	fmt.Fprintf(&b, "(%s) VALUES", strings.Join(columns, ", "))

	if ncols := len(columns); ncols > 0 {
		nrows := nvalues / ncols
		rows := make([]string, 0, nrows)
		placeholders := make([]string, ncols)
		for r := 0; r < nrows; r++ {
			// Placeholders are numbered consecutively across rows, starting at $1.
			for c := range placeholders {
				placeholders[c] = fmt.Sprintf("$%d", r*ncols+c+1)
			}
			rows = append(rows, "("+strings.Join(placeholders, ", ")+")")
		}
		b.WriteString(strings.Join(rows, ", "))
	}
	if conflictAction != "" {
		b.WriteString(" " + conflictAction)
	}
	if len(returningColumns) > 0 {
		fmt.Fprintf(&b, " RETURNING %s", strings.Join(returningColumns, ", "))
	}
	return b.String()
}

// buildUpsertConflictAction returns an
// "ON CONFLICT (<conflictColumns>) DO UPDATE SET c=excluded.c, ..." clause
// that assigns every column in columns from the conflicting row.
func buildUpsertConflictAction(columns, conflictColumns []string) string {
	sets := make([]string, 0, len(columns))
	for _, c := range columns {
		sets = append(sets, fmt.Sprintf("%s=excluded.%[1]s", c))
	}
	return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s",
		strings.Join(conflictColumns, ", "),
		strings.Join(sets, ", "))
}
maxBulkUpdateArrayLen is the maximum size of an array that BulkUpdate will send to Postgres. (Postgres has no size limit on arrays, but we want to keep the statements to a reasonable size.) It is a variable for testing.
BulkUpdate executes multiple UPDATE statements in a transaction. Columns must contain the names of some of table's columns. The first is treated as a key; that is, the values to update are matched with existing rows by comparing the values of the first column. Types holds the database type of each column. For example, []string{"INT", "TEXT"} Values contains one slice of values per column. (Note that this is unlike BulkInsert, which takes a single slice of interleaved values.)
func ( *DB) ( context.Context,  string, ,  []string,  [][]interface{}) ( error) {
	defer derrors.Wrap(&, "DB.BulkUpdate(ctx, tx, %q, %v, [%d values])",
		, , len())

	if len() < 2 {
		return errors.New("need at least two columns")
	}
	if len() != len() {
		return errors.New("len(values) != len(columns)")
	}
	 := len([0])
	for ,  := range [1:] {
		if len() !=  {
			return errors.New("all values slices must be the same length")
		}
	}
	 := buildBulkUpdateQuery(, , )
	for  := 0;  < ;  += maxBulkUpdateArrayLen {
		 :=  + maxBulkUpdateArrayLen
		if  >  {
			 = 
		}
		var  []interface{}
		for ,  := range  {
			 = append(, pq.Array([:]))
		}
		if ,  := .Exec(, , ...);  != nil {
			return fmt.Errorf("db.Exec(%q, values[%d:%d]): %w", , , , )
		}
	}
	return nil
}

// buildBulkUpdateQuery builds an UPDATE statement for BulkUpdate. The first
// column is the key that matches rows; every other column is assigned from
// the corresponding unnested array parameter. Parameter $i holds an array of
// types[i-1] values for columns[i-1], so types must be at least as long as
// columns.
func buildBulkUpdateQuery(table string, columns, types []string) string {
	// Build "c = data.c" for each non-key column.
	sets := make([]string, 0, len(columns)-1)
	for _, c := range columns[1:] {
		sets = append(sets, fmt.Sprintf("%s = data.%[1]s", c))
	}
	// Build "UNNEST($1::TYPE[]) AS c" for each column.
	// We need the type, or Postgres complains that UNNEST is not unique.
	unnests := make([]string, 0, len(columns))
	for i, c := range columns {
		unnests = append(unnests, fmt.Sprintf("UNNEST($%d::%s[]) AS %s", i+1, types[i], c))
	}
	const format = "\n" +
		"\t\tUPDATE %[1]s\n" +
		"\t\tSET %[2]s\n" +
		"\t\tFROM (SELECT %[3]s) AS data\n" +
		"\t\tWHERE %[1]s.%[4]s = data.%[4]s"
	return fmt.Sprintf(format,
		table,                       // 1
		strings.Join(sets, ", "),    // 2
		strings.Join(unnests, ", "), // 3
		columns[0],                  // 4
	)
}
// emptyStringScanner wraps the functionality of sql.NullString to just write
// an empty string if the value is NULL.
type emptyStringScanner struct {
	ptr *string
}

// Scan implements sql.Scanner. It scans value as a sql.NullString and stores
// the resulting string through s.ptr; a NULL value stores "". On error, *s.ptr
// is left unchanged.
func (s emptyStringScanner) Scan(value interface{}) error {
	var ns sql.NullString
	if err := ns.Scan(value); err != nil {
		return err
	}
	*s.ptr = ns.String
	return nil
}
// NullIsEmpty returns a sql.Scanner that writes the empty string to s if the
// scanned value is NULL.