Source File
conn.go
Belonging Package
github.com/lib/pq
package pq
import (
)
var (
ErrNotSupported = errors.New("pq: Unsupported command")
ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
errUnexpectedReady = errors.New("unexpected ReadyForQuery")
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
type Driver struct{}
currentLocation *time.Location
}
type transactionStatus byte
const (
txnStatusIdle transactionStatus = 'I'
txnStatusIdleInTransaction transactionStatus = 'T'
txnStatusInFailedTransaction transactionStatus = 'E'
)
func ( transactionStatus) () string {
switch {
case txnStatusIdle:
return "idle"
case txnStatusIdleInTransaction:
return "idle in transaction"
case txnStatusInFailedTransaction:
return "in a failed transaction"
default:
errorf("unknown transactionStatus %d", )
}
panic("not reached")
}
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type defaultDialer struct {
d net.Dialer
}
func ( defaultDialer) (, string) (net.Conn, error) {
return .d.Dial(, )
}
func ( defaultDialer) (, string, time.Duration) (net.Conn, error) {
, := context.WithTimeout(context.Background(), )
defer ()
return .DialContext(, , )
}
func ( defaultDialer) ( context.Context, , string) (net.Conn, error) {
return .d.DialContext(, , )
}
type conn struct {
c net.Conn
buf *bufio.Reader
namei int
scratch [512]byte
txnStatus transactionStatus
txnFinish func()
func ( *conn) ( values) ( error) {
:= func( string, *bool) error {
if , := []; {
if == "yes" {
* = true
} else if == "no" {
* = false
} else {
return fmt.Errorf("unrecognized value %q for %s", , )
}
}
return nil
}
= ("disable_prepared_binary_result", &.disablePreparedBinaryResult)
if != nil {
return
}
return ("binary_parameters", &.binaryParameters)
}
:= func( string) []string {
:= make([]string, 0, 5)
:= make([]rune, 0, len())
var bool
for , := range {
switch {
case :
= append(, )
= false
case == '\\':
= true
case == ':':
= append(, string())
= [:0]
default:
= append(, )
}
}
return append(, string())
}
for .Scan() {
:= .Text()
if len() == 0 || [0] == '#' {
continue
}
:= ()
if len() != 5 {
continue
}
if ([0] == "*" || [0] == || ([0] == "localhost" && ( == "" || == "unix"))) && ([1] == "*" || [1] == ) && ([2] == "*" || [2] == ) && ([3] == "*" || [3] == ) {
["password"] = [4]
return
}
}
}
func ( *conn) ( byte) *writeBuf {
.scratch[0] =
return &writeBuf{
buf: .scratch[:5],
pos: 1,
}
}
if == "unix" {
["sslmode"] = "disable"
}
:= time.Now().Add()
var net.Conn
if , := .(DialerContext); {
, := context.WithTimeout(, )
defer ()
, = .DialContext(, , )
} else {
, = .DialTimeout(, , )
}
if != nil {
return nil,
}
= .SetDeadline()
return ,
}
if , := .(DialerContext); {
return .DialContext(, , )
}
return .Dial(, )
}
func ( values) (string, string) {
:= ["host"]
if strings.HasPrefix(, "/") {
:= path.Join(, ".s.PGSQL."+["port"])
return "unix",
}
return "tcp", net.JoinHostPort(, ["port"])
}
type values map[string]string
func ( string, values) error {
:= newScanner()
for {
var (
, []rune
rune
bool
)
if , = .SkipSpaces(); ! {
break
}
if != '=' {
, = .SkipSpaces()
}
[string()] = ""
break
}
if != '\'' {
for !unicode.IsSpace() {
if == '\\' {
if , = .Next(); ! {
return fmt.Errorf(`missing character after backslash`)
}
}
= append(, )
if , = .Next(); ! {
break
}
}
} else {
:
for {
if , = .Next(); ! {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch {
case '\'':
break
case '\\':
, _ = .Next()
fallthrough
default:
= append(, )
}
}
}
[string()] = string()
}
return nil
}
func ( *conn) () bool {
return .txnStatus == txnStatusIdleInTransaction ||
.txnStatus == txnStatusInFailedTransaction
}
func ( *conn) ( bool) {
if .isInTransaction() != {
.bad = true
errorf("unexpected transaction status %v", .txnStatus)
}
}
func ( *conn) () ( driver.Tx, error) {
return .begin("")
}
func ( *conn) ( string) ( driver.Tx, error) {
if .bad {
return nil, driver.ErrBadConn
}
defer .errRecover(&)
.checkIsInTransaction(false)
, , := .simpleExec("BEGIN" + )
if != nil {
return nil,
}
if != "BEGIN" {
.bad = true
return nil, fmt.Errorf("unexpected command tag %s", )
}
if .txnStatus != txnStatusIdleInTransaction {
.bad = true
return nil, fmt.Errorf("unexpected transaction status %v", .txnStatus)
}
return , nil
}
func ( *conn) () {
if := .txnFinish; != nil {
()
}
}
func ( *conn) () ( error) {
defer .closeTxn()
if .bad {
return driver.ErrBadConn
}
defer .errRecover(&)
if .txnStatus == txnStatusInFailedTransaction {
if := .rollback(); != nil {
return
}
return ErrInFailedTransaction
}
, , := .simpleExec("COMMIT")
if != nil {
if .isInTransaction() {
.bad = true
}
return
}
if != "COMMIT" {
.bad = true
return fmt.Errorf("unexpected command tag %s", )
}
.checkIsInTransaction(false)
return nil
}
func ( *conn) () ( error) {
defer .closeTxn()
if .bad {
return driver.ErrBadConn
}
defer .errRecover(&)
return .rollback()
}
func ( *conn) () ( error) {
.checkIsInTransaction(true)
, , := .simpleExec("ROLLBACK")
if != nil {
if .isInTransaction() {
.bad = true
}
return
}
if != "ROLLBACK" {
return fmt.Errorf("unexpected command tag %s", )
}
.checkIsInTransaction(false)
return nil
}
func ( *conn) () string {
.namei++
return strconv.FormatInt(int64(.namei), 10)
}
func ( *conn) ( string) ( driver.Result, string, error) {
:= .writeBuf('Q')
.string()
.send()
for {
, := .recv1()
switch {
case 'C':
, = .parseComplete(.string())
case 'Z':
.processReadyForQuery()
if == nil && == nil {
= errUnexpectedReady
return
case 'E':
= parseError()
case 'I':
= emptyRows
.saveMessage(, )
return
= &rows{cn: }
.rowsHeader = parsePortalRowDescribe()
case oid.T_bytea:
fallthrough
case oid.T_int8:
fallthrough
case oid.T_int4:
fallthrough
case oid.T_int2:
fallthrough
case oid.T_uuid:
[] = formatBinary
= false
default:
= false
}
}
if {
return , colFmtDataAllBinary
} else if {
return , colFmtDataAllText
} else {
= make([]byte, 2+len()*2)
binary.BigEndian.PutUint16(, uint16(len()))
for , := range {
binary.BigEndian.PutUint16([2+*2:], uint16())
}
return ,
}
}
func ( *conn) (, string) *stmt {
:= &stmt{cn: , name: }
:= .writeBuf('P')
.string(.name)
.string()
.int16(0)
.next('D')
.byte('S')
.string(.name)
.next('S')
.send()
.readParseResponse()
.paramTyps, .colNames, .colTyps = .readStatementDescribeResponse()
.colFmts, .colFmtData = decideColumnFormats(.colTyps, .disablePreparedBinaryResult)
.readReadyForQuery()
return
}
func ( *conn) ( string) ( driver.Stmt, error) {
if .bad {
return nil, driver.ErrBadConn
}
defer .errRecover(&)
if len() >= 4 && strings.EqualFold([:4], "COPY") {
, := .prepareCopyIn()
if == nil {
.inCopy = true
}
return ,
}
return .prepareTo(, .gname()), nil
}
defer .errRecover(&)
return .sendSimpleMessage('X')
}
if len() == 0 {
return .simpleQuery()
}
if .binaryParameters {
.sendBinaryModeQuery(, )
.readParseResponse()
.readBindResponse()
:= &rows{cn: }
.rowsHeader = .readPortalDescribeResponse()
.postExecuteWorkaround()
return , nil
}
:= .prepareTo(, "")
.exec()
return &rows{
cn: ,
rowsHeader: .rowsHeader,
}, nil
}
, , := .simpleExec()
return ,
}
if .binaryParameters {
.sendBinaryModeQuery(, )
.readParseResponse()
.readBindResponse()
.readPortalDescribeResponse()
.postExecuteWorkaround()
, _, = .readExecuteResponse("Execute")
return ,
func ( *conn) ( byte, *readBuf) {
if .saveMessageType != 0 {
.bad = true
errorf("unexpected saveMessageType %d", .saveMessageType)
}
.saveMessageType =
.saveMessageBuffer = *
}
if .saveMessageType != 0 {
:= .saveMessageType
* = .saveMessageBuffer
.saveMessageType = 0
.saveMessageBuffer = nil
return , nil
}
:= .scratch[:5]
, := io.ReadFull(.buf, )
if != nil {
return 0,
}
func ( *conn) () ( byte, *readBuf) {
for {
var error
= &readBuf{}
, = .recvMessage()
if != nil {
panic()
}
switch {
case 'E':
panic(parseError())
default:
return
}
}
}
case 'S':
.processParameterStatus()
default:
return
}
}
}
func ( string) bool {
switch {
case "host", "port":
return true
case "password":
return true
case "sslmode", "sslcert", "sslkey", "sslrootcert":
return true
case "fallback_application_name":
return true
case "connect_timeout":
return true
case "disable_prepared_binary_result":
return true
case "binary_parameters":
return true
default:
return false
}
}
func ( *conn) ( values) {
:= .writeBuf(0)
for , := range {
continue
if == "dbname" {
= "database"
}
.string()
.string()
}
.string("")
if := .sendStartupPacket(); != nil {
panic()
}
for {
, := .recv()
switch {
case 'K':
.processBackendKeyData()
case 'S':
.processParameterStatus()
case 'R':
.auth(, )
case 'Z':
.processReadyForQuery()
return
default:
errorf("unknown response for startup: %q", )
}
}
}
func ( *conn) ( *readBuf, values) {
switch := .int32(); {
case 3:
:= .writeBuf('p')
.string(["password"])
.send()
, := .recv()
if != 'R' {
errorf("unexpected password response: %q", )
}
if .int32() != 0 {
errorf("unexpected authentication response: %q", )
}
case 5:
:= string(.next(4))
:= .writeBuf('p')
.string("md5" + md5s(md5s(["password"]+["user"])+))
.send()
, := .recv()
if != 'R' {
errorf("unexpected password response: %q", )
}
if .int32() != 0 {
errorf("unexpected authentication response: %q", )
}
case 10:
:= scram.NewClient(sha256.New, ["user"], ["password"])
.Step(nil)
if .Err() != nil {
errorf("SCRAM-SHA-256 error: %s", .Err().Error())
}
:= .Out()
:= .writeBuf('p')
.string("SCRAM-SHA-256")
.int32(len())
.bytes()
.send()
, := .recv()
if != 'R' {
errorf("unexpected password response: %q", )
}
if .int32() != 11 {
errorf("unexpected authentication response: %q", )
}
:= .next(len(*))
.Step()
if .Err() != nil {
errorf("SCRAM-SHA-256 error: %s", .Err().Error())
}
= .Out()
= .writeBuf('p')
.bytes()
.send()
, = .recv()
if != 'R' {
errorf("unexpected password response: %q", )
}
if .int32() != 12 {
errorf("unexpected authentication response: %q", )
}
= .next(len(*))
.Step()
if .Err() != nil {
errorf("SCRAM-SHA-256 error: %s", .Err().Error())
}
default:
errorf("unknown authentication response: %d", )
}
}
type format int
const formatText format = 0
const formatBinary format = 1
var colFmtDataAllBinary = []byte{0, 1, 0, 1}
var colFmtDataAllText = []byte{0, 0}
type stmt struct {
cn *conn
name string
rowsHeader
colFmtData []byte
paramTyps []oid.Oid
closed bool
}
func ( *stmt) () ( error) {
if .closed {
return nil
}
if .cn.bad {
return driver.ErrBadConn
}
defer .cn.errRecover(&)
:= .cn.writeBuf('C')
.byte('S')
.string(.name)
.cn.send()
.cn.send(.cn.writeBuf('S'))
, := .cn.recv1()
if != '3' {
.cn.bad = true
errorf("unexpected close response: %q", )
}
.closed = true
, := .cn.recv1()
if != 'Z' {
.cn.bad = true
errorf("expected ready for query, but got: %q", )
}
.cn.processReadyForQuery()
return nil
}
func ( *stmt) ( []driver.Value) ( driver.Rows, error) {
if .cn.bad {
return nil, driver.ErrBadConn
}
defer .cn.errRecover(&)
.exec()
return &rows{
cn: .cn,
rowsHeader: .rowsHeader,
}, nil
}
func ( *stmt) ( []driver.Value) ( driver.Result, error) {
if .cn.bad {
return nil, driver.ErrBadConn
}
defer .cn.errRecover(&)
.exec()
, _, = .cn.readExecuteResponse("simple query")
return ,
}
func ( *stmt) ( []driver.Value) {
if len() >= 65536 {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len())
}
if len() != len(.paramTyps) {
errorf("got %d parameters but the statement requires %d", len(), len(.paramTyps))
}
:= .cn
:= .writeBuf('B')
.byte(0) // unnamed portal
.string(.name)
if .binaryParameters {
.sendBinaryParameters(, )
} else {
.int16(0)
.int16(len())
for , := range {
if == nil {
.int32(-1)
} else {
:= encode(&.parameterStatus, , .paramTyps[])
.int32(len())
.bytes()
}
}
}
.bytes(.colFmtData)
.next('E')
.byte(0)
.int32(0)
.next('S')
.send()
.readBindResponse()
.postExecuteWorkaround()
}
func ( *stmt) () int {
return len(.paramTyps)
}
if == nil {
return driver.RowsAffected(0),
}
, := strconv.ParseInt(*, 10, 64)
if != nil {
.bad = true
errorf("could not parse commandTag: %s", )
}
return driver.RowsAffected(),
}
type rowsHeader struct {
colNames []string
colTyps []fieldDesc
colFmts []format
}
type rows struct {
cn *conn
finish func()
rowsHeader
done bool
rb readBuf
result driver.Result
tag string
next *rowsHeader
}
func ( *rows) () error {
if := .finish; != nil {
defer ()
if .done {
return nil
}
default:
return
}
}
}
func ( *rows) () []string {
return .colNames
}
func ( *rows) () driver.Result {
if .result == nil {
return emptyRows
}
return .result
}
func ( *rows) () string {
return .tag
}
func ( *rows) ( []driver.Value) ( error) {
if .done {
return io.EOF
}
:= .cn
if .bad {
return driver.ErrBadConn
}
defer .errRecover(&)
for {
:= .recv1Buf(&.rb)
switch {
case 'E':
= parseError(&.rb)
case 'C', 'I':
if == 'C' {
.result, .tag = .parseComplete(.rb.string())
}
continue
case 'Z':
.processReadyForQuery(&.rb)
.done = true
if != nil {
return
}
return io.EOF
case 'D':
:= .rb.int16()
if != nil {
.bad = true
errorf("unexpected DataRow after error %s", )
}
if < len() {
= [:]
}
for := range {
:= .rb.int32()
if == -1 {
[] = nil
continue
}
[] = decode(&.parameterStatus, .rb.next(), .colTyps[].OID, .colFmts[])
}
return
case 'T':
:= parsePortalRowDescribe(&.rb)
.next = &
return io.EOF
default:
errorf("unexpected message after execute: %q", )
}
}
}
func ( *rows) () bool {
:= .next != nil && !.done
return
}
func ( *rows) () error {
if .next == nil {
return io.EOF
}
.rowsHeader = *.next
.next = nil
return nil
}
var []int
for , := range {
, := .([]byte)
if {
if == nil {
= make([]int, len())
}
[] = 1
}
}
if == nil {
.int16(0)
} else {
.int16(len())
for , := range {
.int16()
}
}
.int16(len())
for , := range {
if == nil {
.int32(-1)
} else {
:= binaryEncode(&.parameterStatus, )
.int32(len())
.bytes()
}
}
}
func ( *conn) ( string, []driver.Value) {
if len() >= 65536 {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len())
}
:= .writeBuf('P')
.byte(0) // unnamed statement
.string()
.int16(0)
.next('B')
.int16(0) // unnamed portal and statement
.sendBinaryParameters(, )
.bytes(colFmtDataAllText)
.next('D')
.byte('P')
.byte(0) // unnamed portal
.next('E')
.byte(0)
.int32(0)
.next('S')
.send()
}
func ( *conn) ( *readBuf) {
var error
:= .string()
switch {
case "server_version":
var int
var int
var int
_, = fmt.Sscanf(.string(), "%d.%d.%d", &, &, &)
if == nil {
.parameterStatus.serverVersion = *10000 + *100 +
}
case "TimeZone":
.parameterStatus.currentLocation, = time.LoadLocation(.string())
if != nil {
.parameterStatus.currentLocation = nil
}
}
}
func ( *conn) ( *readBuf) {
.txnStatus = transactionStatus(.byte())
}
func ( *conn) () {
, := .recv1()
switch {
case 'Z':
.processReadyForQuery()
return
default:
.bad = true
errorf("unexpected message %q; expected ReadyForQuery", )
}
}
func ( *conn) ( *readBuf) {
.processID = .int32()
.secretKey = .int32()
}
func ( *conn) () {
, := .recv1()
switch {
case '1':
return
case 'E':
:= parseError()
.readReadyForQuery()
panic()
default:
.bad = true
errorf("unexpected Parse response %q", )
}
}
func ( *conn) () ( []oid.Oid, []string, []fieldDesc) {
for {
, := .recv1()
switch {
case 't':
:= .int16()
= make([]oid.Oid, )
for := range {
[] = .oid()
}
case 'n':
return , nil, nil
case 'T':
, = parseStatementRowDescribe()
return , ,
case 'E':
:= parseError()
.readReadyForQuery()
panic()
default:
.bad = true
errorf("unexpected Describe statement response %q", )
}
}
}
func ( *conn) () rowsHeader {
, := .recv1()
switch {
case 'T':
return parsePortalRowDescribe()
case 'n':
return rowsHeader{}
case 'E':
:= parseError()
.readReadyForQuery()
panic()
default:
.bad = true
errorf("unexpected Describe response %q", )
}
panic("not reached")
}
func ( *conn) () {
, := .recv1()
switch {
case '2':
return
case 'E':
:= parseError()
.readReadyForQuery()
panic()
default:
.bad = true
errorf("unexpected Bind response %q", )
}
}
for {
, := .recv1()
switch {
case 'E':
:= parseError()
.readReadyForQuery()
panic()
.saveMessage(, )
return
default:
.bad = true
errorf("unexpected message during extended query execution: %q", )
}
}
}
func ( *conn) ( string) ( driver.Result, string, error) {
for {
, := .recv1()
switch {
case 'C':
if != nil {
.bad = true
errorf("unexpected CommandComplete after error %s", )
}
, = .parseComplete(.string())
case 'Z':
.processReadyForQuery()
if == nil && == nil {
= errUnexpectedReady
}
return , ,
case 'E':
= parseError()
case 'T', 'D', 'I':
if != nil {
.bad = true
errorf("unexpected %q after error %s", , )
}
if == 'I' {
= emptyRows
switch [0] {
case "PGHOST":
("host")
case "PGHOSTADDR":
()
case "PGPORT":
("port")
case "PGDATABASE":
("dbname")
case "PGUSER":
("user")
case "PGPASSWORD":
("password")
case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
()
case "PGOPTIONS":
("options")
case "PGAPPNAME":
("application_name")
case "PGSSLMODE":
("sslmode")
case "PGSSLCERT":
("sslcert")
case "PGSSLKEY":
("sslkey")
case "PGSSLROOTCERT":
("sslrootcert")
case "PGREQUIRESSL", "PGSSLCRL":
()
case "PGREQUIREPEER":
()
case "PGKRBSRVNAME", "PGGSSLIB":
()
case "PGCONNECT_TIMEOUT":
("connect_timeout")
case "PGCLIENTENCODING":
("client_encoding")
case "PGDATESTYLE":
("datestyle")
case "PGTZ":
("timezone")
case "PGGEQO":
("geqo")
case "PGSYSCONFDIR", "PGLOCALEDIR":
()
}
}
return
}
:= strings.Map(alnumLowerASCII, )
return == "utf8" || == "unicode"
}
func ( rune) rune {
if 'A' <= && <= 'Z' {
return + ('a' - 'A')
}
if 'a' <= && <= 'z' || '0' <= && <= '9' {
return
}
return -1 // discard
![]() |
The pages are generated with Golds v0.3.2-preview. (GOOS=darwin GOARCH=amd64) Golds is a Go 101 project developed by Tapir Liu. PR and bug reports are welcome and can be submitted to the issue list. Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds. |