forked from gitea/gitea
parent
280ebcbf7c
commit
9d4c1ddfa1
|
@ -294,7 +294,7 @@
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/go-sql-driver/mysql"
|
name = "github.com/go-sql-driver/mysql"
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
revision = "ce924a41eea897745442daaa1739089b0f3f561d"
|
revision = "d523deb1b23d913de5bdada721a6071e71283618"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/go-xorm/builder"
|
name = "github.com/go-xorm/builder"
|
||||||
|
@ -873,6 +873,6 @@
|
||||||
[solve-meta]
|
[solve-meta]
|
||||||
analyzer-name = "dep"
|
analyzer-name = "dep"
|
||||||
analyzer-version = 1
|
analyzer-version = 1
|
||||||
inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687"
|
inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8"
|
||||||
solver-name = "gps-cdcl"
|
solver-name = "gps-cdcl"
|
||||||
solver-version = 1
|
solver-version = 1
|
||||||
|
|
|
@ -40,6 +40,10 @@ ignored = ["google.golang.org/appengine*"]
|
||||||
#version = "0.6.5"
|
#version = "0.6.5"
|
||||||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
|
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
|
||||||
|
|
||||||
|
[[override]]
|
||||||
|
name = "github.com/go-sql-driver/mysql"
|
||||||
|
revision = "d523deb1b23d913de5bdada721a6071e71283618"
|
||||||
|
|
||||||
[[override]]
|
[[override]]
|
||||||
name = "github.com/gorilla/mux"
|
name = "github.com/gorilla/mux"
|
||||||
revision = "757bef944d0f21880861c2dd9c871ca543023cba"
|
revision = "757bef944d0f21880861c2dd9c871ca543023cba"
|
||||||
|
|
|
@ -12,34 +12,63 @@
|
||||||
# Individual Persons
|
# Individual Persons
|
||||||
|
|
||||||
Aaron Hopkins <go-sql-driver at die.net>
|
Aaron Hopkins <go-sql-driver at die.net>
|
||||||
|
Achille Roussel <achille.roussel at gmail.com>
|
||||||
|
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
|
||||||
|
Andrew Reid <andrew.reid at tixtrack.com>
|
||||||
Arne Hormann <arnehormann at gmail.com>
|
Arne Hormann <arnehormann at gmail.com>
|
||||||
|
Asta Xie <xiemengjun at gmail.com>
|
||||||
|
Bulat Gaifullin <gaifullinbf at gmail.com>
|
||||||
Carlos Nieto <jose.carlos at menteslibres.net>
|
Carlos Nieto <jose.carlos at menteslibres.net>
|
||||||
Chris Moos <chris at tech9computers.com>
|
Chris Moos <chris at tech9computers.com>
|
||||||
|
Craig Wilson <craiggwilson at gmail.com>
|
||||||
|
Daniel Montoya <dsmontoyam at gmail.com>
|
||||||
Daniel Nichter <nil at codenode.com>
|
Daniel Nichter <nil at codenode.com>
|
||||||
Daniël van Eeden <git at myname.nl>
|
Daniël van Eeden <git at myname.nl>
|
||||||
|
Dave Protasowski <dprotaso at gmail.com>
|
||||||
DisposaBoy <disposaboy at dby.me>
|
DisposaBoy <disposaboy at dby.me>
|
||||||
|
Egor Smolyakov <egorsmkv at gmail.com>
|
||||||
|
Evan Shaw <evan at vendhq.com>
|
||||||
Frederick Mayle <frederickmayle at gmail.com>
|
Frederick Mayle <frederickmayle at gmail.com>
|
||||||
Gustavo Kristic <gkristic at gmail.com>
|
Gustavo Kristic <gkristic at gmail.com>
|
||||||
|
Hajime Nakagami <nakagami at gmail.com>
|
||||||
Hanno Braun <mail at hannobraun.com>
|
Hanno Braun <mail at hannobraun.com>
|
||||||
Henri Yandell <flamefew at gmail.com>
|
Henri Yandell <flamefew at gmail.com>
|
||||||
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
||||||
|
ICHINOSE Shogo <shogo82148 at gmail.com>
|
||||||
INADA Naoki <songofacandy at gmail.com>
|
INADA Naoki <songofacandy at gmail.com>
|
||||||
|
Jacek Szwec <szwec.jacek at gmail.com>
|
||||||
James Harr <james.harr at gmail.com>
|
James Harr <james.harr at gmail.com>
|
||||||
|
Jeff Hodges <jeff at somethingsimilar.com>
|
||||||
|
Jeffrey Charles <jeffreycharles at gmail.com>
|
||||||
Jian Zhen <zhenjl at gmail.com>
|
Jian Zhen <zhenjl at gmail.com>
|
||||||
Joshua Prunier <joshua.prunier at gmail.com>
|
Joshua Prunier <joshua.prunier at gmail.com>
|
||||||
Julien Lefevre <julien.lefevr at gmail.com>
|
Julien Lefevre <julien.lefevr at gmail.com>
|
||||||
Julien Schmidt <go-sql-driver at julienschmidt.com>
|
Julien Schmidt <go-sql-driver at julienschmidt.com>
|
||||||
|
Justin Li <jli at j-li.net>
|
||||||
|
Justin Nuß <nuss.justin at gmail.com>
|
||||||
Kamil Dziedzic <kamil at klecza.pl>
|
Kamil Dziedzic <kamil at klecza.pl>
|
||||||
Kevin Malachowski <kevin at chowski.com>
|
Kevin Malachowski <kevin at chowski.com>
|
||||||
|
Kieron Woodhouse <kieron.woodhouse at infosum.com>
|
||||||
Lennart Rudolph <lrudolph at hmc.edu>
|
Lennart Rudolph <lrudolph at hmc.edu>
|
||||||
Leonardo YongUk Kim <dalinaum at gmail.com>
|
Leonardo YongUk Kim <dalinaum at gmail.com>
|
||||||
|
Linh Tran Tuan <linhduonggnu at gmail.com>
|
||||||
|
Lion Yang <lion at aosc.xyz>
|
||||||
Luca Looz <luca.looz92 at gmail.com>
|
Luca Looz <luca.looz92 at gmail.com>
|
||||||
Lucas Liu <extrafliu at gmail.com>
|
Lucas Liu <extrafliu at gmail.com>
|
||||||
Luke Scott <luke at webconnex.com>
|
Luke Scott <luke at webconnex.com>
|
||||||
|
Maciej Zimnoch <maciej.zimnoch at codilime.com>
|
||||||
Michael Woolnough <michael.woolnough at gmail.com>
|
Michael Woolnough <michael.woolnough at gmail.com>
|
||||||
Nicola Peduzzi <thenikso at gmail.com>
|
Nicola Peduzzi <thenikso at gmail.com>
|
||||||
|
Olivier Mengué <dolmen at cpan.org>
|
||||||
|
oscarzhao <oscarzhaosl at gmail.com>
|
||||||
Paul Bonser <misterpib at gmail.com>
|
Paul Bonser <misterpib at gmail.com>
|
||||||
|
Peter Schultz <peter.schultz at classmarkets.com>
|
||||||
|
Rebecca Chin <rchin at pivotal.io>
|
||||||
|
Reed Allman <rdallman10 at gmail.com>
|
||||||
|
Richard Wilkes <wilkes at me.com>
|
||||||
|
Robert Russell <robert at rrbrussell.com>
|
||||||
Runrioter Wung <runrioter at gmail.com>
|
Runrioter Wung <runrioter at gmail.com>
|
||||||
|
Shuode Li <elemount at qq.com>
|
||||||
Soroush Pour <me at soroushjp.com>
|
Soroush Pour <me at soroushjp.com>
|
||||||
Stan Putrya <root.vagner at gmail.com>
|
Stan Putrya <root.vagner at gmail.com>
|
||||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||||
|
@ -51,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com>
|
||||||
# Organizations
|
# Organizations
|
||||||
|
|
||||||
Barracuda Networks, Inc.
|
Barracuda Networks, Inc.
|
||||||
|
Counting Ltd.
|
||||||
Google Inc.
|
Google Inc.
|
||||||
|
InfoSum Ltd.
|
||||||
|
Keybase Inc.
|
||||||
|
Percona LLC
|
||||||
|
Pivotal Inc.
|
||||||
Stripe Inc.
|
Stripe Inc.
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"appengine/cloudsql"
|
"google.golang.org/appengine/cloudsql"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -0,0 +1,420 @@
|
||||||
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||||
|
//
|
||||||
|
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// server pub keys registry
|
||||||
|
var (
|
||||||
|
serverPubKeyLock sync.RWMutex
|
||||||
|
serverPubKeyRegistry map[string]*rsa.PublicKey
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterServerPubKey registers a server RSA public key which can be used to
|
||||||
|
// send data in a secure manner to the server without receiving the public key
|
||||||
|
// in a potentially insecure way from the server first.
|
||||||
|
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
|
||||||
|
//
|
||||||
|
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
|
||||||
|
// after registering it and may not be modified.
|
||||||
|
//
|
||||||
|
// data, err := ioutil.ReadFile("mykey.pem")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// block, _ := pem.Decode(data)
|
||||||
|
// if block == nil || block.Type != "PUBLIC KEY" {
|
||||||
|
// log.Fatal("failed to decode PEM block containing public key")
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
|
||||||
|
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
|
||||||
|
// } else {
|
||||||
|
// log.Fatal("not a RSA public key")
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
|
||||||
|
serverPubKeyLock.Lock()
|
||||||
|
if serverPubKeyRegistry == nil {
|
||||||
|
serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverPubKeyRegistry[name] = pubKey
|
||||||
|
serverPubKeyLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeregisterServerPubKey removes the public key registered with the given name.
|
||||||
|
func DeregisterServerPubKey(name string) {
|
||||||
|
serverPubKeyLock.Lock()
|
||||||
|
if serverPubKeyRegistry != nil {
|
||||||
|
delete(serverPubKeyRegistry, name)
|
||||||
|
}
|
||||||
|
serverPubKeyLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
|
||||||
|
serverPubKeyLock.RLock()
|
||||||
|
if v, ok := serverPubKeyRegistry[name]; ok {
|
||||||
|
pubKey = v
|
||||||
|
}
|
||||||
|
serverPubKeyLock.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash password using pre 4.1 (old password) method
|
||||||
|
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
|
||||||
|
type myRnd struct {
|
||||||
|
seed1, seed2 uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
const myRndMaxVal = 0x3FFFFFFF
|
||||||
|
|
||||||
|
// Pseudo random number generator
|
||||||
|
func newMyRnd(seed1, seed2 uint32) *myRnd {
|
||||||
|
return &myRnd{
|
||||||
|
seed1: seed1 % myRndMaxVal,
|
||||||
|
seed2: seed2 % myRndMaxVal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tested to be equivalent to MariaDB's floating point variant
|
||||||
|
// http://play.golang.org/p/QHvhd4qved
|
||||||
|
// http://play.golang.org/p/RG0q4ElWDx
|
||||||
|
func (r *myRnd) NextByte() byte {
|
||||||
|
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
|
||||||
|
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
|
||||||
|
|
||||||
|
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate binary hash from byte string using insecure pre 4.1 method
|
||||||
|
func pwHash(password []byte) (result [2]uint32) {
|
||||||
|
var add uint32 = 7
|
||||||
|
var tmp uint32
|
||||||
|
|
||||||
|
result[0] = 1345345333
|
||||||
|
result[1] = 0x12345671
|
||||||
|
|
||||||
|
for _, c := range password {
|
||||||
|
// skip spaces and tabs in password
|
||||||
|
if c == ' ' || c == '\t' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp = uint32(c)
|
||||||
|
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
|
||||||
|
result[1] += (result[1] << 8) ^ result[0]
|
||||||
|
add += tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove sign bit (1<<31)-1)
|
||||||
|
result[0] &= 0x7FFFFFFF
|
||||||
|
result[1] &= 0x7FFFFFFF
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash password using insecure pre 4.1 method
|
||||||
|
func scrambleOldPassword(scramble []byte, password string) []byte {
|
||||||
|
if len(password) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
scramble = scramble[:8]
|
||||||
|
|
||||||
|
hashPw := pwHash([]byte(password))
|
||||||
|
hashSc := pwHash(scramble)
|
||||||
|
|
||||||
|
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
||||||
|
|
||||||
|
var out [8]byte
|
||||||
|
for i := range out {
|
||||||
|
out[i] = r.NextByte() + 64
|
||||||
|
}
|
||||||
|
|
||||||
|
mask := r.NextByte()
|
||||||
|
for i := range out {
|
||||||
|
out[i] ^= mask
|
||||||
|
}
|
||||||
|
|
||||||
|
return out[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash password using 4.1+ method (SHA1)
|
||||||
|
func scramblePassword(scramble []byte, password string) []byte {
|
||||||
|
if len(password) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stage1Hash = SHA1(password)
|
||||||
|
crypt := sha1.New()
|
||||||
|
crypt.Write([]byte(password))
|
||||||
|
stage1 := crypt.Sum(nil)
|
||||||
|
|
||||||
|
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||||
|
// inner Hash
|
||||||
|
crypt.Reset()
|
||||||
|
crypt.Write(stage1)
|
||||||
|
hash := crypt.Sum(nil)
|
||||||
|
|
||||||
|
// outer Hash
|
||||||
|
crypt.Reset()
|
||||||
|
crypt.Write(scramble)
|
||||||
|
crypt.Write(hash)
|
||||||
|
scramble = crypt.Sum(nil)
|
||||||
|
|
||||||
|
// token = scrambleHash XOR stage1Hash
|
||||||
|
for i := range scramble {
|
||||||
|
scramble[i] ^= stage1[i]
|
||||||
|
}
|
||||||
|
return scramble
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash password using MySQL 8+ method (SHA256)
|
||||||
|
func scrambleSHA256Password(scramble []byte, password string) []byte {
|
||||||
|
if len(password) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
|
||||||
|
|
||||||
|
crypt := sha256.New()
|
||||||
|
crypt.Write([]byte(password))
|
||||||
|
message1 := crypt.Sum(nil)
|
||||||
|
|
||||||
|
crypt.Reset()
|
||||||
|
crypt.Write(message1)
|
||||||
|
message1Hash := crypt.Sum(nil)
|
||||||
|
|
||||||
|
crypt.Reset()
|
||||||
|
crypt.Write(message1Hash)
|
||||||
|
crypt.Write(scramble)
|
||||||
|
message2 := crypt.Sum(nil)
|
||||||
|
|
||||||
|
for i := range message1 {
|
||||||
|
message1[i] ^= message2[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return message1
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
|
||||||
|
plain := make([]byte, len(password)+1)
|
||||||
|
copy(plain, password)
|
||||||
|
for i := range plain {
|
||||||
|
j := i % len(seed)
|
||||||
|
plain[i] ^= seed[j]
|
||||||
|
}
|
||||||
|
sha1 := sha1.New()
|
||||||
|
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
|
||||||
|
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return mc.writeAuthSwitchPacket(enc, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
|
||||||
|
switch plugin {
|
||||||
|
case "caching_sha2_password":
|
||||||
|
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
|
||||||
|
return authResp, (authResp == nil), nil
|
||||||
|
|
||||||
|
case "mysql_old_password":
|
||||||
|
if !mc.cfg.AllowOldPasswords {
|
||||||
|
return nil, false, ErrOldPassword
|
||||||
|
}
|
||||||
|
// Note: there are edge cases where this should work but doesn't;
|
||||||
|
// this is currently "wontfix":
|
||||||
|
// https://github.com/go-sql-driver/mysql/issues/184
|
||||||
|
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
|
||||||
|
return authResp, true, nil
|
||||||
|
|
||||||
|
case "mysql_clear_password":
|
||||||
|
if !mc.cfg.AllowCleartextPasswords {
|
||||||
|
return nil, false, ErrCleartextPassword
|
||||||
|
}
|
||||||
|
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||||
|
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||||
|
return []byte(mc.cfg.Passwd), true, nil
|
||||||
|
|
||||||
|
case "mysql_native_password":
|
||||||
|
if !mc.cfg.AllowNativePasswords {
|
||||||
|
return nil, false, ErrNativePassword
|
||||||
|
}
|
||||||
|
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||||
|
// Native password authentication only need and will need 20-byte challenge.
|
||||||
|
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||||
|
return authResp, false, nil
|
||||||
|
|
||||||
|
case "sha256_password":
|
||||||
|
if len(mc.cfg.Passwd) == 0 {
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||||
|
// write cleartext auth packet
|
||||||
|
return []byte(mc.cfg.Passwd), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := mc.cfg.pubKey
|
||||||
|
if pubKey == nil {
|
||||||
|
// request public key from server
|
||||||
|
return []byte{1}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encrypted password
|
||||||
|
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
|
||||||
|
return enc, false, err
|
||||||
|
|
||||||
|
default:
|
||||||
|
errLog.Print("unknown auth plugin:", plugin)
|
||||||
|
return nil, false, ErrUnknownPlugin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||||
|
// Read Result Packet
|
||||||
|
authData, newPlugin, err := mc.readAuthResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle auth plugin switch, if requested
|
||||||
|
if newPlugin != "" {
|
||||||
|
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
|
||||||
|
// sent and we have to keep using the cipher sent in the init packet.
|
||||||
|
if authData == nil {
|
||||||
|
authData = oldAuthData
|
||||||
|
} else {
|
||||||
|
// copy data from read buffer to owned slice
|
||||||
|
copy(oldAuthData, authData)
|
||||||
|
}
|
||||||
|
|
||||||
|
plugin = newPlugin
|
||||||
|
|
||||||
|
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read Result Packet
|
||||||
|
authData, newPlugin, err = mc.readAuthResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not allow to change the auth plugin more than once
|
||||||
|
if newPlugin != "" {
|
||||||
|
return ErrMalformPkt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch plugin {
|
||||||
|
|
||||||
|
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
|
||||||
|
case "caching_sha2_password":
|
||||||
|
switch len(authData) {
|
||||||
|
case 0:
|
||||||
|
return nil // auth successful
|
||||||
|
case 1:
|
||||||
|
switch authData[0] {
|
||||||
|
case cachingSha2PasswordFastAuthSuccess:
|
||||||
|
if err = mc.readResultOK(); err == nil {
|
||||||
|
return nil // auth successful
|
||||||
|
}
|
||||||
|
|
||||||
|
case cachingSha2PasswordPerformFullAuthentication:
|
||||||
|
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||||
|
// write cleartext auth packet
|
||||||
|
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pubKey := mc.cfg.pubKey
|
||||||
|
if pubKey == nil {
|
||||||
|
// request public key from server
|
||||||
|
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||||
|
data[4] = cachingSha2PasswordRequestPublicKey
|
||||||
|
mc.writePacket(data)
|
||||||
|
|
||||||
|
// parse public key
|
||||||
|
data, err := mc.readPacket()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(data[1:])
|
||||||
|
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pubKey = pkix.(*rsa.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// send encrypted password
|
||||||
|
err = mc.sendEncryptedPassword(oldAuthData, pubKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mc.readResultOK()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return ErrMalformPkt
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrMalformPkt
|
||||||
|
}
|
||||||
|
|
||||||
|
case "sha256_password":
|
||||||
|
switch len(authData) {
|
||||||
|
case 0:
|
||||||
|
return nil // auth successful
|
||||||
|
default:
|
||||||
|
block, _ := pem.Decode(authData)
|
||||||
|
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// send encrypted password
|
||||||
|
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return mc.readResultOK()
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil // auth successful
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
|
@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte {
|
||||||
// smaller than defaultBufSize
|
// smaller than defaultBufSize
|
||||||
// Only one buffer (total) can be used at a time.
|
// Only one buffer (total) can be used at a time.
|
||||||
func (b *buffer) takeSmallBuffer(length int) []byte {
|
func (b *buffer) takeSmallBuffer(length int) []byte {
|
||||||
if b.length == 0 {
|
if b.length > 0 {
|
||||||
return b.buf[:length]
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
return b.buf[:length]
|
||||||
}
|
}
|
||||||
|
|
||||||
// takeCompleteBuffer returns the complete existing buffer.
|
// takeCompleteBuffer returns the complete existing buffer.
|
||||||
// This can be used if the necessary buffer size is unknown.
|
// This can be used if the necessary buffer size is unknown.
|
||||||
// Only one buffer (total) can be used at a time.
|
// Only one buffer (total) can be used at a time.
|
||||||
func (b *buffer) takeCompleteBuffer() []byte {
|
func (b *buffer) takeCompleteBuffer() []byte {
|
||||||
if b.length == 0 {
|
if b.length > 0 {
|
||||||
return b.buf
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
return b.buf
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
const defaultCollation = "utf8_general_ci"
|
const defaultCollation = "utf8_general_ci"
|
||||||
|
const binaryCollation = "binary"
|
||||||
|
|
||||||
// A list of available collations mapped to the internal ID.
|
// A list of available collations mapped to the internal ID.
|
||||||
// To update this map use the following MySQL query:
|
// To update this map use the following MySQL query:
|
||||||
|
|
|
@ -10,12 +10,23 @@ package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// a copy of context.Context for Go 1.7 and earlier
|
||||||
|
type mysqlContext interface {
|
||||||
|
Done() <-chan struct{}
|
||||||
|
Err() error
|
||||||
|
|
||||||
|
// defined in context.Context, but not used in this driver:
|
||||||
|
// Deadline() (deadline time.Time, ok bool)
|
||||||
|
// Value(key interface{}) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
type mysqlConn struct {
|
type mysqlConn struct {
|
||||||
buf buffer
|
buf buffer
|
||||||
netConn net.Conn
|
netConn net.Conn
|
||||||
|
@ -29,7 +40,14 @@ type mysqlConn struct {
|
||||||
status statusFlag
|
status statusFlag
|
||||||
sequence uint8
|
sequence uint8
|
||||||
parseTime bool
|
parseTime bool
|
||||||
strict bool
|
|
||||||
|
// for context support (Go 1.8+)
|
||||||
|
watching bool
|
||||||
|
watcher chan<- mysqlContext
|
||||||
|
closech chan struct{}
|
||||||
|
finished chan<- struct{}
|
||||||
|
canceled atomicError // set non-nil if conn is canceled
|
||||||
|
closed atomicBool // set when conn is closed, before closech is closed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handles parameters set in DSN after the connection is established
|
// Handles parameters set in DSN after the connection is established
|
||||||
|
@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) markBadConn(err error) error {
|
||||||
|
if mc == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err != errBadConnNoWrite {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||||||
if mc.netConn == nil {
|
return mc.begin(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
|
||||||
|
if mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
err := mc.exec("START TRANSACTION")
|
var q string
|
||||||
|
if readOnly {
|
||||||
|
q = "START TRANSACTION READ ONLY"
|
||||||
|
} else {
|
||||||
|
q = "START TRANSACTION"
|
||||||
|
}
|
||||||
|
err := mc.exec(q)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return &mysqlTx{mc}, err
|
return &mysqlTx{mc}, err
|
||||||
}
|
}
|
||||||
|
return nil, mc.markBadConn(err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mysqlConn) Close() (err error) {
|
func (mc *mysqlConn) Close() (err error) {
|
||||||
// Makes Close idempotent
|
// Makes Close idempotent
|
||||||
if mc.netConn != nil {
|
if !mc.closed.IsSet() {
|
||||||
err = mc.writeCommandPacket(comQuit)
|
err = mc.writeCommandPacket(comQuit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) {
|
||||||
// is called before auth or on auth failure because MySQL will have already
|
// is called before auth or on auth failure because MySQL will have already
|
||||||
// closed the network connection.
|
// closed the network connection.
|
||||||
func (mc *mysqlConn) cleanup() {
|
func (mc *mysqlConn) cleanup() {
|
||||||
|
if !mc.closed.TrySet(true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Makes cleanup idempotent
|
// Makes cleanup idempotent
|
||||||
if mc.netConn != nil {
|
close(mc.closech)
|
||||||
|
if mc.netConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := mc.netConn.Close(); err != nil {
|
if err := mc.netConn.Close(); err != nil {
|
||||||
errLog.Print(err)
|
errLog.Print(err)
|
||||||
}
|
}
|
||||||
mc.netConn = nil
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) error() error {
|
||||||
|
if mc.closed.IsSet() {
|
||||||
|
if err := mc.canceled.Value(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
mc.cfg = nil
|
return ErrInvalidConn
|
||||||
mc.buf.nc = nil
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||||
if mc.netConn == nil {
|
if mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
// Send command
|
// Send command
|
||||||
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, mc.markBadConn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := &mysqlStmt{
|
stmt := &mysqlStmt{
|
||||||
|
@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
||||||
if buf == nil {
|
if buf == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// can not take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return "", driver.ErrBadConn
|
return "", ErrInvalidConn
|
||||||
}
|
}
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
argPos := 0
|
argPos := 0
|
||||||
|
@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||||
if mc.netConn == nil {
|
if mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
query = prepared
|
query = prepared
|
||||||
args = nil
|
|
||||||
}
|
}
|
||||||
mc.affectedRows = 0
|
mc.affectedRows = 0
|
||||||
mc.insertId = 0
|
mc.insertId = 0
|
||||||
|
@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
|
||||||
insertId: int64(mc.insertId),
|
insertId: int64(mc.insertId),
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, mc.markBadConn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal function to execute commands
|
// Internal function to execute commands
|
||||||
func (mc *mysqlConn) exec(query string) error {
|
func (mc *mysqlConn) exec(query string) error {
|
||||||
// Send command
|
// Send command
|
||||||
err := mc.writeCommandPacketStr(comQuery, query)
|
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
||||||
if err != nil {
|
return mc.markBadConn(err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read Result
|
// Read Result
|
||||||
resLen, err := mc.readResultSetHeaderPacket()
|
resLen, err := mc.readResultSetHeaderPacket()
|
||||||
if err == nil && resLen > 0 {
|
if err != nil {
|
||||||
if err = mc.readUntilEOF(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = mc.readUntilEOF()
|
if resLen > 0 {
|
||||||
|
// columns
|
||||||
|
if err := mc.readUntilEOF(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rows
|
||||||
|
if err := mc.readUntilEOF(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mc.discardResults()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||||
if mc.netConn == nil {
|
return mc.query(query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
|
||||||
|
if mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
query = prepared
|
query = prepared
|
||||||
args = nil
|
|
||||||
}
|
}
|
||||||
// Send command
|
// Send command
|
||||||
err := mc.writeCommandPacketStr(comQuery, query)
|
err := mc.writeCommandPacketStr(comQuery, query)
|
||||||
|
@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
||||||
rows.mc = mc
|
rows.mc = mc
|
||||||
|
|
||||||
if resLen == 0 {
|
if resLen == 0 {
|
||||||
// no columns, no more data
|
rows.rs.done = true
|
||||||
return emptyRows{}, nil
|
|
||||||
|
switch err := rows.NextResultSet(); err {
|
||||||
|
case nil, io.EOF:
|
||||||
|
return rows, nil
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Columns
|
// Columns
|
||||||
rows.columns, err = mc.readColumns(resLen)
|
rows.rs.columns, err = mc.readColumns(resLen)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, mc.markBadConn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gets the value of the given MySQL System Variable
|
// Gets the value of the given MySQL System Variable
|
||||||
|
@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows := new(textRows)
|
rows := new(textRows)
|
||||||
rows.mc = mc
|
rows.mc = mc
|
||||||
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||||
|
|
||||||
if resLen > 0 {
|
if resLen > 0 {
|
||||||
// Columns
|
// Columns
|
||||||
|
@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finish is called when the query has canceled.
|
||||||
|
func (mc *mysqlConn) cancel(err error) {
|
||||||
|
mc.canceled.Set(err)
|
||||||
|
mc.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// finish is called when the query has succeeded.
|
||||||
|
func (mc *mysqlConn) finish() {
|
||||||
|
if !mc.watching || mc.finished == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case mc.finished <- struct{}{}:
|
||||||
|
mc.watching = false
|
||||||
|
case <-mc.closech:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||||
|
//
|
||||||
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ping implements driver.Pinger interface
|
||||||
|
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||||
|
if mc.closed.IsSet() {
|
||||||
|
errLog.Print(ErrInvalidConn)
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = mc.watchCancel(ctx); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mc.finish()
|
||||||
|
|
||||||
|
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return mc.readResultOK()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginTx implements driver.ConnBeginTx interface
|
||||||
|
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
|
if err := mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer mc.finish()
|
||||||
|
|
||||||
|
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
|
||||||
|
level, err := mapIsolationLevel(opts.Isolation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mc.begin(opts.ReadOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
dargs, err := namedValueToValue(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := mc.query(query, dargs)
|
||||||
|
if err != nil {
|
||||||
|
mc.finish()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows.finish = mc.finish
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
dargs, err := namedValueToValue(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer mc.finish()
|
||||||
|
|
||||||
|
return mc.Exec(query, dargs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||||
|
if err := mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := mc.Prepare(query)
|
||||||
|
mc.finish()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
default:
|
||||||
|
case <-ctx.Done():
|
||||||
|
stmt.Close()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
return stmt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
dargs, err := namedValueToValue(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := stmt.query(dargs)
|
||||||
|
if err != nil {
|
||||||
|
stmt.mc.finish()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows.finish = stmt.mc.finish
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
dargs, err := namedValueToValue(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer stmt.mc.finish()
|
||||||
|
|
||||||
|
return stmt.Exec(dargs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||||
|
if mc.watching {
|
||||||
|
// Reach here if canceled,
|
||||||
|
// so the connection is already invalid
|
||||||
|
mc.cleanup()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ctx.Done() == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mc.watching = true
|
||||||
|
select {
|
||||||
|
default:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
if mc.watcher == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mc.watcher <- ctx
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) startWatcher() {
|
||||||
|
watcher := make(chan mysqlContext, 1)
|
||||||
|
mc.watcher = watcher
|
||||||
|
finished := make(chan struct{})
|
||||||
|
mc.finished = finished
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var ctx mysqlContext
|
||||||
|
select {
|
||||||
|
case ctx = <-watcher:
|
||||||
|
case <-mc.closech:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
mc.cancel(ctx.Err())
|
||||||
|
case <-finished:
|
||||||
|
case <-mc.closech:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||||
|
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSession implements driver.SessionResetter.
|
||||||
|
// (From Go 1.10)
|
||||||
|
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
||||||
|
if mc.closed.IsSet() {
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -9,7 +9,9 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
const (
|
const (
|
||||||
minProtocolVersion byte = 10
|
defaultAuthPlugin = "mysql_native_password"
|
||||||
|
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
|
||||||
|
minProtocolVersion = 10
|
||||||
maxPacketSize = 1<<24 - 1
|
maxPacketSize = 1<<24 - 1
|
||||||
timeFormat = "2006-01-02 15:04:05.999999"
|
timeFormat = "2006-01-02 15:04:05.999999"
|
||||||
)
|
)
|
||||||
|
@ -19,6 +21,7 @@ const (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
iOK byte = 0x00
|
iOK byte = 0x00
|
||||||
|
iAuthMoreData byte = 0x01
|
||||||
iLocalInFile byte = 0xfb
|
iLocalInFile byte = 0xfb
|
||||||
iEOF byte = 0xfe
|
iEOF byte = 0xfe
|
||||||
iERR byte = 0xff
|
iERR byte = 0xff
|
||||||
|
@ -87,8 +90,10 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
|
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
|
||||||
|
type fieldType byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fieldTypeDecimal byte = iota
|
fieldTypeDecimal fieldType = iota
|
||||||
fieldTypeTiny
|
fieldTypeTiny
|
||||||
fieldTypeShort
|
fieldTypeShort
|
||||||
fieldTypeLong
|
fieldTypeLong
|
||||||
|
@ -107,7 +112,7 @@ const (
|
||||||
fieldTypeBit
|
fieldTypeBit
|
||||||
)
|
)
|
||||||
const (
|
const (
|
||||||
fieldTypeJSON byte = iota + 0xf5
|
fieldTypeJSON fieldType = iota + 0xf5
|
||||||
fieldTypeNewDecimal
|
fieldTypeNewDecimal
|
||||||
fieldTypeEnum
|
fieldTypeEnum
|
||||||
fieldTypeSet
|
fieldTypeSet
|
||||||
|
@ -161,3 +166,9 @@ const (
|
||||||
statusInTransReadonly
|
statusInTransReadonly
|
||||||
statusSessionStateChanged
|
statusSessionStateChanged
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cachingSha2PasswordRequestPublicKey = 2
|
||||||
|
cachingSha2PasswordFastAuthSuccess = 3
|
||||||
|
cachingSha2PasswordPerformFullAuthentication = 4
|
||||||
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
// Package mysql provides a MySQL driver for Go's database/sql package
|
// Package mysql provides a MySQL driver for Go's database/sql package.
|
||||||
//
|
//
|
||||||
// The driver should be used via the database/sql package:
|
// The driver should be used via the database/sql package:
|
||||||
//
|
//
|
||||||
|
@ -20,8 +20,14 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// watcher interface is used for context support (From Go 1.8)
|
||||||
|
type watcher interface {
|
||||||
|
startWatcher()
|
||||||
|
}
|
||||||
|
|
||||||
// MySQLDriver is exported to make the driver directly accessible.
|
// MySQLDriver is exported to make the driver directly accessible.
|
||||||
// In general the driver is used via the database/sql package.
|
// In general the driver is used via the database/sql package.
|
||||||
type MySQLDriver struct{}
|
type MySQLDriver struct{}
|
||||||
|
@ -30,12 +36,17 @@ type MySQLDriver struct{}
|
||||||
// Custom dial functions must be registered with RegisterDial
|
// Custom dial functions must be registered with RegisterDial
|
||||||
type DialFunc func(addr string) (net.Conn, error)
|
type DialFunc func(addr string) (net.Conn, error)
|
||||||
|
|
||||||
var dials map[string]DialFunc
|
var (
|
||||||
|
dialsLock sync.RWMutex
|
||||||
|
dials map[string]DialFunc
|
||||||
|
)
|
||||||
|
|
||||||
// RegisterDial registers a custom dial function. It can then be used by the
|
// RegisterDial registers a custom dial function. It can then be used by the
|
||||||
// network address mynet(addr), where mynet is the registered new network.
|
// network address mynet(addr), where mynet is the registered new network.
|
||||||
// addr is passed as a parameter to the dial function.
|
// addr is passed as a parameter to the dial function.
|
||||||
func RegisterDial(net string, dial DialFunc) {
|
func RegisterDial(net string, dial DialFunc) {
|
||||||
|
dialsLock.Lock()
|
||||||
|
defer dialsLock.Unlock()
|
||||||
if dials == nil {
|
if dials == nil {
|
||||||
dials = make(map[string]DialFunc)
|
dials = make(map[string]DialFunc)
|
||||||
}
|
}
|
||||||
|
@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
mc := &mysqlConn{
|
mc := &mysqlConn{
|
||||||
maxAllowedPacket: maxPacketSize,
|
maxAllowedPacket: maxPacketSize,
|
||||||
maxWriteSize: maxPacketSize - 1,
|
maxWriteSize: maxPacketSize - 1,
|
||||||
|
closech: make(chan struct{}),
|
||||||
}
|
}
|
||||||
mc.cfg, err = ParseDSN(dsn)
|
mc.cfg, err = ParseDSN(dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mc.parseTime = mc.cfg.ParseTime
|
mc.parseTime = mc.cfg.ParseTime
|
||||||
mc.strict = mc.cfg.Strict
|
|
||||||
|
|
||||||
// Connect to Server
|
// Connect to Server
|
||||||
if dial, ok := dials[mc.cfg.Net]; ok {
|
dialsLock.RLock()
|
||||||
|
dial, ok := dials[mc.cfg.Net]
|
||||||
|
dialsLock.RUnlock()
|
||||||
|
if ok {
|
||||||
mc.netConn, err = dial(mc.cfg.Addr)
|
mc.netConn, err = dial(mc.cfg.Addr)
|
||||||
} else {
|
} else {
|
||||||
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
||||||
|
@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call startWatcher for context support (From Go 1.8)
|
||||||
|
if s, ok := interface{}(mc).(watcher); ok {
|
||||||
|
s.startWatcher()
|
||||||
|
}
|
||||||
|
|
||||||
mc.buf = newBuffer(mc.netConn)
|
mc.buf = newBuffer(mc.netConn)
|
||||||
|
|
||||||
// Set I/O timeouts
|
// Set I/O timeouts
|
||||||
|
@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
mc.writeTimeout = mc.cfg.WriteTimeout
|
mc.writeTimeout = mc.cfg.WriteTimeout
|
||||||
|
|
||||||
// Reading Handshake Initialization Packet
|
// Reading Handshake Initialization Packet
|
||||||
cipher, err := mc.readInitPacket()
|
authData, plugin, err := mc.readHandshakePacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mc.cleanup()
|
mc.cleanup()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send Client Authentication Packet
|
// Send Client Authentication Packet
|
||||||
if err = mc.writeAuthPacket(cipher); err != nil {
|
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||||
|
if err != nil {
|
||||||
|
// try the default auth plugin, if using the requested plugin failed
|
||||||
|
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||||
|
plugin = defaultAuthPlugin
|
||||||
|
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||||
|
if err != nil {
|
||||||
|
mc.cleanup()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||||
mc.cleanup()
|
mc.cleanup()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle response to auth packet, switch methods if possible
|
// Handle response to auth packet, switch methods if possible
|
||||||
if err = handleAuthResult(mc); err != nil {
|
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||||
// Authentication failed and MySQL has already closed the connection
|
// Authentication failed and MySQL has already closed the connection
|
||||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||||
// Do not send COM_QUIT, just cleanup and return the error.
|
// Do not send COM_QUIT, just cleanup and return the error.
|
||||||
|
@ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return mc, nil
|
return mc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAuthResult(mc *mysqlConn) error {
|
|
||||||
// Read Result Packet
|
|
||||||
cipher, err := mc.readResultOK()
|
|
||||||
if err == nil {
|
|
||||||
return nil // auth successful
|
|
||||||
}
|
|
||||||
|
|
||||||
if mc.cfg == nil {
|
|
||||||
return err // auth failed and retry not possible
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retry auth if configured to do so.
|
|
||||||
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
|
|
||||||
// Retry with old authentication method. Note: there are edge cases
|
|
||||||
// where this should work but doesn't; this is currently "wontfix":
|
|
||||||
// https://github.com/go-sql-driver/mysql/issues/184
|
|
||||||
if err = mc.writeOldAuthPacket(cipher); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = mc.readResultOK()
|
|
||||||
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
|
|
||||||
// Retry with clear text password for
|
|
||||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
|
||||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
|
||||||
if err = mc.writeClearAuthPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = mc.readResultOK()
|
|
||||||
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword {
|
|
||||||
if err = mc.writeNativeAuthPacket(cipher); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = mc.readResultOK()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("mysql", &MySQLDriver{})
|
sql.Register("mysql", &MySQLDriver{})
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,11 +10,13 @@ package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -27,7 +29,9 @@ var (
|
||||||
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
|
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is a configuration parsed from a DSN string
|
// Config is a configuration parsed from a DSN string.
|
||||||
|
// If a new Config is created instead of being parsed from a DSN string,
|
||||||
|
// the NewConfig function should be used, which sets default values.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
User string // Username
|
User string // Username
|
||||||
Passwd string // Password (requires User)
|
Passwd string // Password (requires User)
|
||||||
|
@ -38,6 +42,8 @@ type Config struct {
|
||||||
Collation string // Connection collation
|
Collation string // Connection collation
|
||||||
Loc *time.Location // Location for time.Time values
|
Loc *time.Location // Location for time.Time values
|
||||||
MaxAllowedPacket int // Max packet size allowed
|
MaxAllowedPacket int // Max packet size allowed
|
||||||
|
ServerPubKey string // Server public key name
|
||||||
|
pubKey *rsa.PublicKey // Server public key
|
||||||
TLSConfig string // TLS configuration name
|
TLSConfig string // TLS configuration name
|
||||||
tls *tls.Config // TLS configuration
|
tls *tls.Config // TLS configuration
|
||||||
Timeout time.Duration // Dial timeout
|
Timeout time.Duration // Dial timeout
|
||||||
|
@ -53,7 +59,54 @@ type Config struct {
|
||||||
InterpolateParams bool // Interpolate placeholders into query string
|
InterpolateParams bool // Interpolate placeholders into query string
|
||||||
MultiStatements bool // Allow multiple statements in one query
|
MultiStatements bool // Allow multiple statements in one query
|
||||||
ParseTime bool // Parse time values to time.Time
|
ParseTime bool // Parse time values to time.Time
|
||||||
Strict bool // Return warnings as errors
|
RejectReadOnly bool // Reject read-only connections
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfig creates a new Config and sets default values.
|
||||||
|
func NewConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
Collation: defaultCollation,
|
||||||
|
Loc: time.UTC,
|
||||||
|
MaxAllowedPacket: defaultMaxAllowedPacket,
|
||||||
|
AllowNativePasswords: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) normalize() error {
|
||||||
|
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
|
||||||
|
return errInvalidDSNUnsafeCollation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default network if empty
|
||||||
|
if cfg.Net == "" {
|
||||||
|
cfg.Net = "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default address if empty
|
||||||
|
if cfg.Addr == "" {
|
||||||
|
switch cfg.Net {
|
||||||
|
case "tcp":
|
||||||
|
cfg.Addr = "127.0.0.1:3306"
|
||||||
|
case "unix":
|
||||||
|
cfg.Addr = "/tmp/mysql.sock"
|
||||||
|
default:
|
||||||
|
return errors.New("default addr for network '" + cfg.Net + "' unknown")
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if cfg.Net == "tcp" {
|
||||||
|
cfg.Addr = ensureHavePort(cfg.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.tls != nil {
|
||||||
|
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
|
||||||
|
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||||
|
if err == nil {
|
||||||
|
cfg.tls.ServerName = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FormatDSN formats the given Config into a DSN string which can be passed to
|
// FormatDSN formats the given Config into a DSN string which can be passed to
|
||||||
|
@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.AllowNativePasswords {
|
if !cfg.AllowNativePasswords {
|
||||||
if hasParam {
|
if hasParam {
|
||||||
buf.WriteString("&allowNativePasswords=true")
|
buf.WriteString("&allowNativePasswords=false")
|
||||||
} else {
|
} else {
|
||||||
hasParam = true
|
hasParam = true
|
||||||
buf.WriteString("?allowNativePasswords=true")
|
buf.WriteString("?allowNativePasswords=false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string {
|
||||||
buf.WriteString(cfg.ReadTimeout.String())
|
buf.WriteString(cfg.ReadTimeout.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Strict {
|
if cfg.RejectReadOnly {
|
||||||
if hasParam {
|
if hasParam {
|
||||||
buf.WriteString("&strict=true")
|
buf.WriteString("&rejectReadOnly=true")
|
||||||
} else {
|
} else {
|
||||||
hasParam = true
|
hasParam = true
|
||||||
buf.WriteString("?strict=true")
|
buf.WriteString("?rejectReadOnly=true")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(cfg.ServerPubKey) > 0 {
|
||||||
|
if hasParam {
|
||||||
|
buf.WriteString("&serverPubKey=")
|
||||||
|
} else {
|
||||||
|
hasParam = true
|
||||||
|
buf.WriteString("?serverPubKey=")
|
||||||
|
}
|
||||||
|
buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Timeout > 0 {
|
if cfg.Timeout > 0 {
|
||||||
if hasParam {
|
if hasParam {
|
||||||
buf.WriteString("&timeout=")
|
buf.WriteString("&timeout=")
|
||||||
|
@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string {
|
||||||
buf.WriteString(cfg.WriteTimeout.String())
|
buf.WriteString(cfg.WriteTimeout.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.MaxAllowedPacket > 0 {
|
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
|
||||||
if hasParam {
|
if hasParam {
|
||||||
buf.WriteString("&maxAllowedPacket=")
|
buf.WriteString("&maxAllowedPacket=")
|
||||||
} else {
|
} else {
|
||||||
|
@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string {
|
||||||
|
|
||||||
// other params
|
// other params
|
||||||
if cfg.Params != nil {
|
if cfg.Params != nil {
|
||||||
for param, value := range cfg.Params {
|
var params []string
|
||||||
|
for param := range cfg.Params {
|
||||||
|
params = append(params, param)
|
||||||
|
}
|
||||||
|
sort.Strings(params)
|
||||||
|
for _, param := range params {
|
||||||
if hasParam {
|
if hasParam {
|
||||||
buf.WriteByte('&')
|
buf.WriteByte('&')
|
||||||
} else {
|
} else {
|
||||||
|
@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string {
|
||||||
|
|
||||||
buf.WriteString(param)
|
buf.WriteString(param)
|
||||||
buf.WriteByte('=')
|
buf.WriteByte('=')
|
||||||
buf.WriteString(url.QueryEscape(value))
|
buf.WriteString(url.QueryEscape(cfg.Params[param]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string {
|
||||||
// ParseDSN parses the DSN string to a Config
|
// ParseDSN parses the DSN string to a Config
|
||||||
func ParseDSN(dsn string) (cfg *Config, err error) {
|
func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||||
// New config with some default values
|
// New config with some default values
|
||||||
cfg = &Config{
|
cfg = NewConfig()
|
||||||
Loc: time.UTC,
|
|
||||||
Collation: defaultCollation,
|
|
||||||
}
|
|
||||||
|
|
||||||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
||||||
// Find the last '/' (since the password or the net addr might contain a '/')
|
// Find the last '/' (since the password or the net addr might contain a '/')
|
||||||
|
@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||||
return nil, errInvalidDSNNoSlash
|
return nil, errInvalidDSNNoSlash
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
|
if err = cfg.normalize(); err != nil {
|
||||||
return nil, errInvalidDSNUnsafeCollation
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default network if empty
|
|
||||||
if cfg.Net == "" {
|
|
||||||
cfg.Net = "tcp"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set default address if empty
|
|
||||||
if cfg.Addr == "" {
|
|
||||||
switch cfg.Net {
|
|
||||||
case "tcp":
|
|
||||||
cfg.Addr = "127.0.0.1:3306"
|
|
||||||
case "unix":
|
|
||||||
cfg.Addr = "/tmp/mysql.sock"
|
|
||||||
default:
|
|
||||||
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||||
|
|
||||||
// cfg params
|
// cfg params
|
||||||
switch value := param[1]; param[0] {
|
switch value := param[1]; param[0] {
|
||||||
|
|
||||||
// Disable INFILE whitelist / enable all files
|
// Disable INFILE whitelist / enable all files
|
||||||
case "allowAllFiles":
|
case "allowAllFiles":
|
||||||
var isBool bool
|
var isBool bool
|
||||||
|
@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strict mode
|
// Reject read-only connections
|
||||||
case "strict":
|
case "rejectReadOnly":
|
||||||
var isBool bool
|
var isBool bool
|
||||||
cfg.Strict, isBool = readBool(value)
|
cfg.RejectReadOnly, isBool = readBool(value)
|
||||||
if !isBool {
|
if !isBool {
|
||||||
return errors.New("invalid bool value: " + value)
|
return errors.New("invalid bool value: " + value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server public key
|
||||||
|
case "serverPubKey":
|
||||||
|
name, err := url.QueryUnescape(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid value for server pub key name: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pubKey := getServerPubKey(name); pubKey != nil {
|
||||||
|
cfg.ServerPubKey = name
|
||||||
|
cfg.pubKey = pubKey
|
||||||
|
} else {
|
||||||
|
return errors.New("invalid value / unknown server pub key name: " + name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strict mode
|
||||||
|
case "strict":
|
||||||
|
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
|
||||||
|
|
||||||
// Dial Timeout
|
// Dial Timeout
|
||||||
case "timeout":
|
case "timeout":
|
||||||
cfg.Timeout, err = time.ParseDuration(value)
|
cfg.Timeout, err = time.ParseDuration(value)
|
||||||
|
@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||||
return fmt.Errorf("invalid value for TLS config name: %v", err)
|
return fmt.Errorf("invalid value for TLS config name: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsConfig, ok := tlsConfigRegister[name]; ok {
|
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
|
||||||
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
|
|
||||||
host, _, err := net.SplitHostPort(cfg.Addr)
|
|
||||||
if err == nil {
|
|
||||||
tlsConfig.ServerName = host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.TLSConfig = name
|
cfg.TLSConfig = name
|
||||||
cfg.tls = tlsConfig
|
cfg.tls = tlsConfig
|
||||||
} else {
|
} else {
|
||||||
|
@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureHavePort(addr string) string {
|
||||||
|
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||||
|
return net.JoinHostPort(addr, "3306")
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
|
@ -9,10 +9,8 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
@ -31,6 +29,12 @@ var (
|
||||||
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
|
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
|
||||||
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
|
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
|
||||||
ErrBusyBuffer = errors.New("busy buffer")
|
ErrBusyBuffer = errors.New("busy buffer")
|
||||||
|
|
||||||
|
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
|
||||||
|
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
|
||||||
|
// to trigger a resend.
|
||||||
|
// See https://github.com/go-sql-driver/mysql/pull/302
|
||||||
|
errBadConnNoWrite = errors.New("bad connection")
|
||||||
)
|
)
|
||||||
|
|
||||||
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
|
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
|
||||||
|
@ -59,74 +63,3 @@ type MySQLError struct {
|
||||||
func (me *MySQLError) Error() string {
|
func (me *MySQLError) Error() string {
|
||||||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
|
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MySQLWarnings is an error type which represents a group of one or more MySQL
|
|
||||||
// warnings
|
|
||||||
type MySQLWarnings []MySQLWarning
|
|
||||||
|
|
||||||
func (mws MySQLWarnings) Error() string {
|
|
||||||
var msg string
|
|
||||||
for i, warning := range mws {
|
|
||||||
if i > 0 {
|
|
||||||
msg += "\r\n"
|
|
||||||
}
|
|
||||||
msg += fmt.Sprintf(
|
|
||||||
"%s %s: %s",
|
|
||||||
warning.Level,
|
|
||||||
warning.Code,
|
|
||||||
warning.Message,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// MySQLWarning is an error type which represents a single MySQL warning.
|
|
||||||
// Warnings are returned in groups only. See MySQLWarnings
|
|
||||||
type MySQLWarning struct {
|
|
||||||
Level string
|
|
||||||
Code string
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mysqlConn) getWarnings() (err error) {
|
|
||||||
rows, err := mc.Query("SHOW WARNINGS", nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var warnings = MySQLWarnings{}
|
|
||||||
var values = make([]driver.Value, 3)
|
|
||||||
|
|
||||||
for {
|
|
||||||
err = rows.Next(values)
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
warning := MySQLWarning{}
|
|
||||||
|
|
||||||
if raw, ok := values[0].([]byte); ok {
|
|
||||||
warning.Level = string(raw)
|
|
||||||
} else {
|
|
||||||
warning.Level = fmt.Sprintf("%s", values[0])
|
|
||||||
}
|
|
||||||
if raw, ok := values[1].([]byte); ok {
|
|
||||||
warning.Code = string(raw)
|
|
||||||
} else {
|
|
||||||
warning.Code = fmt.Sprintf("%s", values[1])
|
|
||||||
}
|
|
||||||
if raw, ok := values[2].([]byte); ok {
|
|
||||||
warning.Message = string(raw)
|
|
||||||
} else {
|
|
||||||
warning.Message = fmt.Sprintf("%s", values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
warnings = append(warnings, warning)
|
|
||||||
|
|
||||||
case io.EOF:
|
|
||||||
return warnings
|
|
||||||
|
|
||||||
default:
|
|
||||||
rows.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||||
|
//
|
||||||
|
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (mf *mysqlField) typeDatabaseName() string {
|
||||||
|
switch mf.fieldType {
|
||||||
|
case fieldTypeBit:
|
||||||
|
return "BIT"
|
||||||
|
case fieldTypeBLOB:
|
||||||
|
if mf.charSet != collations[binaryCollation] {
|
||||||
|
return "TEXT"
|
||||||
|
}
|
||||||
|
return "BLOB"
|
||||||
|
case fieldTypeDate:
|
||||||
|
return "DATE"
|
||||||
|
case fieldTypeDateTime:
|
||||||
|
return "DATETIME"
|
||||||
|
case fieldTypeDecimal:
|
||||||
|
return "DECIMAL"
|
||||||
|
case fieldTypeDouble:
|
||||||
|
return "DOUBLE"
|
||||||
|
case fieldTypeEnum:
|
||||||
|
return "ENUM"
|
||||||
|
case fieldTypeFloat:
|
||||||
|
return "FLOAT"
|
||||||
|
case fieldTypeGeometry:
|
||||||
|
return "GEOMETRY"
|
||||||
|
case fieldTypeInt24:
|
||||||
|
return "MEDIUMINT"
|
||||||
|
case fieldTypeJSON:
|
||||||
|
return "JSON"
|
||||||
|
case fieldTypeLong:
|
||||||
|
return "INT"
|
||||||
|
case fieldTypeLongBLOB:
|
||||||
|
if mf.charSet != collations[binaryCollation] {
|
||||||
|
return "LONGTEXT"
|
||||||
|
}
|
||||||
|
return "LONGBLOB"
|
||||||
|
case fieldTypeLongLong:
|
||||||
|
return "BIGINT"
|
||||||
|
case fieldTypeMediumBLOB:
|
||||||
|
if mf.charSet != collations[binaryCollation] {
|
||||||
|
return "MEDIUMTEXT"
|
||||||
|
}
|
||||||
|
return "MEDIUMBLOB"
|
||||||
|
case fieldTypeNewDate:
|
||||||
|
return "DATE"
|
||||||
|
case fieldTypeNewDecimal:
|
||||||
|
return "DECIMAL"
|
||||||
|
case fieldTypeNULL:
|
||||||
|
return "NULL"
|
||||||
|
case fieldTypeSet:
|
||||||
|
return "SET"
|
||||||
|
case fieldTypeShort:
|
||||||
|
return "SMALLINT"
|
||||||
|
case fieldTypeString:
|
||||||
|
if mf.charSet == collations[binaryCollation] {
|
||||||
|
return "BINARY"
|
||||||
|
}
|
||||||
|
return "CHAR"
|
||||||
|
case fieldTypeTime:
|
||||||
|
return "TIME"
|
||||||
|
case fieldTypeTimestamp:
|
||||||
|
return "TIMESTAMP"
|
||||||
|
case fieldTypeTiny:
|
||||||
|
return "TINYINT"
|
||||||
|
case fieldTypeTinyBLOB:
|
||||||
|
if mf.charSet != collations[binaryCollation] {
|
||||||
|
return "TINYTEXT"
|
||||||
|
}
|
||||||
|
return "TINYBLOB"
|
||||||
|
case fieldTypeVarChar:
|
||||||
|
if mf.charSet == collations[binaryCollation] {
|
||||||
|
return "VARBINARY"
|
||||||
|
}
|
||||||
|
return "VARCHAR"
|
||||||
|
case fieldTypeVarString:
|
||||||
|
if mf.charSet == collations[binaryCollation] {
|
||||||
|
return "VARBINARY"
|
||||||
|
}
|
||||||
|
return "VARCHAR"
|
||||||
|
case fieldTypeYear:
|
||||||
|
return "YEAR"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
scanTypeFloat32 = reflect.TypeOf(float32(0))
|
||||||
|
scanTypeFloat64 = reflect.TypeOf(float64(0))
|
||||||
|
scanTypeInt8 = reflect.TypeOf(int8(0))
|
||||||
|
scanTypeInt16 = reflect.TypeOf(int16(0))
|
||||||
|
scanTypeInt32 = reflect.TypeOf(int32(0))
|
||||||
|
scanTypeInt64 = reflect.TypeOf(int64(0))
|
||||||
|
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
|
||||||
|
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
|
||||||
|
scanTypeNullTime = reflect.TypeOf(NullTime{})
|
||||||
|
scanTypeUint8 = reflect.TypeOf(uint8(0))
|
||||||
|
scanTypeUint16 = reflect.TypeOf(uint16(0))
|
||||||
|
scanTypeUint32 = reflect.TypeOf(uint32(0))
|
||||||
|
scanTypeUint64 = reflect.TypeOf(uint64(0))
|
||||||
|
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
|
||||||
|
scanTypeUnknown = reflect.TypeOf(new(interface{}))
|
||||||
|
)
|
||||||
|
|
||||||
|
type mysqlField struct {
|
||||||
|
tableName string
|
||||||
|
name string
|
||||||
|
length uint32
|
||||||
|
flags fieldFlag
|
||||||
|
fieldType fieldType
|
||||||
|
decimals byte
|
||||||
|
charSet uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mf *mysqlField) scanType() reflect.Type {
|
||||||
|
switch mf.fieldType {
|
||||||
|
case fieldTypeTiny:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
if mf.flags&flagUnsigned != 0 {
|
||||||
|
return scanTypeUint8
|
||||||
|
}
|
||||||
|
return scanTypeInt8
|
||||||
|
}
|
||||||
|
return scanTypeNullInt
|
||||||
|
|
||||||
|
case fieldTypeShort, fieldTypeYear:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
if mf.flags&flagUnsigned != 0 {
|
||||||
|
return scanTypeUint16
|
||||||
|
}
|
||||||
|
return scanTypeInt16
|
||||||
|
}
|
||||||
|
return scanTypeNullInt
|
||||||
|
|
||||||
|
case fieldTypeInt24, fieldTypeLong:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
if mf.flags&flagUnsigned != 0 {
|
||||||
|
return scanTypeUint32
|
||||||
|
}
|
||||||
|
return scanTypeInt32
|
||||||
|
}
|
||||||
|
return scanTypeNullInt
|
||||||
|
|
||||||
|
case fieldTypeLongLong:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
if mf.flags&flagUnsigned != 0 {
|
||||||
|
return scanTypeUint64
|
||||||
|
}
|
||||||
|
return scanTypeInt64
|
||||||
|
}
|
||||||
|
return scanTypeNullInt
|
||||||
|
|
||||||
|
case fieldTypeFloat:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
return scanTypeFloat32
|
||||||
|
}
|
||||||
|
return scanTypeNullFloat
|
||||||
|
|
||||||
|
case fieldTypeDouble:
|
||||||
|
if mf.flags&flagNotNULL != 0 {
|
||||||
|
return scanTypeFloat64
|
||||||
|
}
|
||||||
|
return scanTypeNullFloat
|
||||||
|
|
||||||
|
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
|
||||||
|
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
|
||||||
|
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
|
||||||
|
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
|
||||||
|
fieldTypeTime:
|
||||||
|
return scanTypeRawBytes
|
||||||
|
|
||||||
|
case fieldTypeDate, fieldTypeNewDate,
|
||||||
|
fieldTypeTimestamp, fieldTypeDateTime:
|
||||||
|
// NullTime is always returned for more consistent behavior as it can
|
||||||
|
// handle both cases of parseTime regardless if the field is nullable.
|
||||||
|
return scanTypeNullTime
|
||||||
|
|
||||||
|
default:
|
||||||
|
return scanTypeUnknown
|
||||||
|
}
|
||||||
|
}
|
|
@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// send content packets
|
// send content packets
|
||||||
if err == nil {
|
// if packetSize == 0, the Reader contains no data
|
||||||
|
if err == nil && packetSize > 0 {
|
||||||
data := make([]byte, 4+packetSize)
|
data := make([]byte, 4+packetSize)
|
||||||
var n int
|
var n int
|
||||||
for err == nil {
|
for err == nil {
|
||||||
|
@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
||||||
|
|
||||||
// read OK packet
|
// read OK packet
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = mc.readResultOK()
|
return mc.readResultOK()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mc.readPacket()
|
mc.readPacket()
|
||||||
|
|
|
@ -25,26 +25,23 @@ import (
|
||||||
|
|
||||||
// Read packet to buffer 'data'
|
// Read packet to buffer 'data'
|
||||||
func (mc *mysqlConn) readPacket() ([]byte, error) {
|
func (mc *mysqlConn) readPacket() ([]byte, error) {
|
||||||
var payload []byte
|
var prevData []byte
|
||||||
for {
|
for {
|
||||||
// Read packet header
|
// read packet header
|
||||||
data, err := mc.buf.readNext(4)
|
data, err := mc.buf.readNext(4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if cerr := mc.canceled.Value(); cerr != nil {
|
||||||
|
return nil, cerr
|
||||||
|
}
|
||||||
errLog.Print(err)
|
errLog.Print(err)
|
||||||
mc.Close()
|
mc.Close()
|
||||||
return nil, driver.ErrBadConn
|
return nil, ErrInvalidConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Packet Length [24 bit]
|
// packet length [24 bit]
|
||||||
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
|
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
|
||||||
|
|
||||||
if pktLen < 1 {
|
// check packet sync [8 bit]
|
||||||
errLog.Print(ErrMalformPkt)
|
|
||||||
mc.Close()
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check Packet Sync [8 bit]
|
|
||||||
if data[3] != mc.sequence {
|
if data[3] != mc.sequence {
|
||||||
if data[3] > mc.sequence {
|
if data[3] > mc.sequence {
|
||||||
return nil, ErrPktSyncMul
|
return nil, ErrPktSyncMul
|
||||||
|
@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
||||||
}
|
}
|
||||||
mc.sequence++
|
mc.sequence++
|
||||||
|
|
||||||
// Read packet body [pktLen bytes]
|
// packets with length 0 terminate a previous packet which is a
|
||||||
data, err = mc.buf.readNext(pktLen)
|
// multiple of (2^24)−1 bytes long
|
||||||
if err != nil {
|
if pktLen == 0 {
|
||||||
errLog.Print(err)
|
// there was no previous packet
|
||||||
|
if prevData == nil {
|
||||||
|
errLog.Print(ErrMalformPkt)
|
||||||
mc.Close()
|
mc.Close()
|
||||||
return nil, driver.ErrBadConn
|
return nil, ErrInvalidConn
|
||||||
}
|
}
|
||||||
|
|
||||||
isLastPacket := (pktLen < maxPacketSize)
|
return prevData, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Zero allocations for non-splitting packets
|
// read packet body [pktLen bytes]
|
||||||
if isLastPacket && payload == nil {
|
data, err = mc.buf.readNext(pktLen)
|
||||||
|
if err != nil {
|
||||||
|
if cerr := mc.canceled.Value(); cerr != nil {
|
||||||
|
return nil, cerr
|
||||||
|
}
|
||||||
|
errLog.Print(err)
|
||||||
|
mc.Close()
|
||||||
|
return nil, ErrInvalidConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// return data if this was the last packet
|
||||||
|
if pktLen < maxPacketSize {
|
||||||
|
// zero allocations for non-split packets
|
||||||
|
if prevData == nil {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = append(payload, data...)
|
return append(prevData, data...), nil
|
||||||
|
|
||||||
if isLastPacket {
|
|
||||||
return payload, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prevData = append(prevData, data...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
||||||
|
|
||||||
// Handle error
|
// Handle error
|
||||||
if err == nil { // n != len(data)
|
if err == nil { // n != len(data)
|
||||||
|
mc.cleanup()
|
||||||
errLog.Print(ErrMalformPkt)
|
errLog.Print(ErrMalformPkt)
|
||||||
} else {
|
} else {
|
||||||
|
if cerr := mc.canceled.Value(); cerr != nil {
|
||||||
|
return cerr
|
||||||
|
}
|
||||||
|
if n == 0 && pktLen == len(data)-4 {
|
||||||
|
// only for the first loop iteration when nothing was written yet
|
||||||
|
return errBadConnNoWrite
|
||||||
|
}
|
||||||
|
mc.cleanup()
|
||||||
errLog.Print(err)
|
errLog.Print(err)
|
||||||
}
|
}
|
||||||
return driver.ErrBadConn
|
return ErrInvalidConn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/******************************************************************************
|
/******************************************************************************
|
||||||
* Initialisation Process *
|
* Initialization Process *
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
// Handshake Initialization Packet
|
// Handshake Initialization Packet
|
||||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||||
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
|
||||||
data, err := mc.readPacket()
|
data, err := mc.readPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
|
||||||
|
// in connection initialization we don't risk retrying non-idempotent actions.
|
||||||
|
if err == ErrInvalidConn {
|
||||||
|
return nil, "", driver.ErrBadConn
|
||||||
|
}
|
||||||
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if data[0] == iERR {
|
if data[0] == iERR {
|
||||||
return nil, mc.handleErrorPacket(data)
|
return nil, "", mc.handleErrorPacket(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// protocol version [1 byte]
|
// protocol version [1 byte]
|
||||||
if data[0] < minProtocolVersion {
|
if data[0] < minProtocolVersion {
|
||||||
return nil, fmt.Errorf(
|
return nil, "", fmt.Errorf(
|
||||||
"unsupported protocol version %d. Version %d or higher is required",
|
"unsupported protocol version %d. Version %d or higher is required",
|
||||||
data[0],
|
data[0],
|
||||||
minProtocolVersion,
|
minProtocolVersion,
|
||||||
|
@ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
||||||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
|
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
|
||||||
|
|
||||||
// first part of the password cipher [8 bytes]
|
// first part of the password cipher [8 bytes]
|
||||||
cipher := data[pos : pos+8]
|
authData := data[pos : pos+8]
|
||||||
|
|
||||||
// (filler) always 0x00 [1 byte]
|
// (filler) always 0x00 [1 byte]
|
||||||
pos += 8 + 1
|
pos += 8 + 1
|
||||||
|
@ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
||||||
// capability flags (lower 2 bytes) [2 bytes]
|
// capability flags (lower 2 bytes) [2 bytes]
|
||||||
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||||
if mc.flags&clientProtocol41 == 0 {
|
if mc.flags&clientProtocol41 == 0 {
|
||||||
return nil, ErrOldProtocol
|
return nil, "", ErrOldProtocol
|
||||||
}
|
}
|
||||||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
||||||
return nil, ErrNoTLS
|
return nil, "", ErrNoTLS
|
||||||
}
|
}
|
||||||
pos += 2
|
pos += 2
|
||||||
|
|
||||||
|
plugin := ""
|
||||||
if len(data) > pos {
|
if len(data) > pos {
|
||||||
// character set [1 byte]
|
// character set [1 byte]
|
||||||
// status flags [2 bytes]
|
// status flags [2 bytes]
|
||||||
|
@ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
||||||
//
|
//
|
||||||
// The official Python library uses the fixed length 12
|
// The official Python library uses the fixed length 12
|
||||||
// which seems to work but technically could have a hidden bug.
|
// which seems to work but technically could have a hidden bug.
|
||||||
cipher = append(cipher, data[pos:pos+12]...)
|
authData = append(authData, data[pos:pos+12]...)
|
||||||
|
pos += 13
|
||||||
|
|
||||||
// TODO: Verify string termination
|
|
||||||
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
|
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
|
||||||
// \NUL otherwise
|
// \NUL otherwise
|
||||||
//
|
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
|
||||||
//if data[len(data)-1] == 0 {
|
plugin = string(data[pos : pos+end])
|
||||||
// return
|
} else {
|
||||||
//}
|
plugin = string(data[pos:])
|
||||||
//return ErrMalformPkt
|
|
||||||
|
|
||||||
// make a memory safe copy of the cipher slice
|
|
||||||
var b [20]byte
|
|
||||||
copy(b[:], cipher)
|
|
||||||
return b[:], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make a memory safe copy of the cipher slice
|
||||||
|
var b [20]byte
|
||||||
|
copy(b[:], authData)
|
||||||
|
return b[:], plugin, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
plugin = defaultAuthPlugin
|
||||||
|
|
||||||
// make a memory safe copy of the cipher slice
|
// make a memory safe copy of the cipher slice
|
||||||
var b [8]byte
|
var b [8]byte
|
||||||
copy(b[:], cipher)
|
copy(b[:], authData)
|
||||||
return b[:], nil
|
return b[:], plugin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client Authentication Packet
|
// Client Authentication Packet
|
||||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||||
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
|
||||||
// Adjust client flags based on server support
|
// Adjust client flags based on server support
|
||||||
clientFlags := clientProtocol41 |
|
clientFlags := clientProtocol41 |
|
||||||
clientSecureConn |
|
clientSecureConn |
|
||||||
|
@ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
||||||
clientFlags |= clientMultiStatements
|
clientFlags |= clientMultiStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
// User Password
|
// encode length of the auth plugin data
|
||||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
var authRespLEIBuf [9]byte
|
||||||
|
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
|
||||||
|
if len(authRespLEI) > 1 {
|
||||||
|
// if the length can not be written in 1 byte, it must be written as a
|
||||||
|
// length encoded integer
|
||||||
|
clientFlags |= clientPluginAuthLenEncClientData
|
||||||
|
}
|
||||||
|
|
||||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
|
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
||||||
|
if addNUL {
|
||||||
|
pktLen++
|
||||||
|
}
|
||||||
|
|
||||||
// To specify a db name
|
// To specify a db name
|
||||||
if n := len(mc.cfg.DBName); n > 0 {
|
if n := len(mc.cfg.DBName); n > 0 {
|
||||||
|
@ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
||||||
// Calculate packet length and get buffer with that size
|
// Calculate packet length and get buffer with that size
|
||||||
data := mc.buf.takeSmallBuffer(pktLen + 4)
|
data := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientFlags [32 bit]
|
// ClientFlags [32 bit]
|
||||||
|
@ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
||||||
data[pos] = 0x00
|
data[pos] = 0x00
|
||||||
pos++
|
pos++
|
||||||
|
|
||||||
// ScrambleBuffer [length encoded integer]
|
// Auth Data [length encoded integer]
|
||||||
data[pos] = byte(len(scrambleBuff))
|
pos += copy(data[pos:], authRespLEI)
|
||||||
pos += 1 + copy(data[pos+1:], scrambleBuff)
|
pos += copy(data[pos:], authResp)
|
||||||
|
if addNUL {
|
||||||
|
data[pos] = 0x00
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
|
||||||
// Databasename [null terminated string]
|
// Databasename [null terminated string]
|
||||||
if len(mc.cfg.DBName) > 0 {
|
if len(mc.cfg.DBName) > 0 {
|
||||||
|
@ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
||||||
pos++
|
pos++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assume native client during response
|
pos += copy(data[pos:], plugin)
|
||||||
pos += copy(data[pos:], "mysql_native_password")
|
|
||||||
data[pos] = 0x00
|
data[pos] = 0x00
|
||||||
|
|
||||||
// Send Auth packet
|
// Send Auth packet
|
||||||
return mc.writePacket(data)
|
return mc.writePacket(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client old authentication packet
|
|
||||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||||
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
|
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
|
||||||
// User password
|
pktLen := 4 + len(authData)
|
||||||
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
|
if addNUL {
|
||||||
|
pktLen++
|
||||||
// Calculate the packet length and add a tailing 0
|
}
|
||||||
pktLen := len(scrambleBuff) + 1
|
data := mc.buf.takeSmallBuffer(pktLen)
|
||||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the scrambled password [null terminated string]
|
// Add the auth data [EOF]
|
||||||
copy(data[4:], scrambleBuff)
|
copy(data[4:], authData)
|
||||||
data[4+pktLen-1] = 0x00
|
if addNUL {
|
||||||
|
data[pktLen-1] = 0x00
|
||||||
return mc.writePacket(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client clear text authentication packet
|
|
||||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
||||||
func (mc *mysqlConn) writeClearAuthPacket() error {
|
|
||||||
// Calculate the packet length and add a tailing 0
|
|
||||||
pktLen := len(mc.cfg.Passwd) + 1
|
|
||||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
|
||||||
if data == nil {
|
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
|
||||||
errLog.Print(ErrBusyBuffer)
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the clear password [null terminated string]
|
|
||||||
copy(data[4:], mc.cfg.Passwd)
|
|
||||||
data[4+pktLen-1] = 0x00
|
|
||||||
|
|
||||||
return mc.writePacket(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Native password authentication method
|
|
||||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
||||||
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
|
||||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
|
||||||
|
|
||||||
// Calculate the packet length and add a tailing 0
|
|
||||||
pktLen := len(scrambleBuff)
|
|
||||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
|
||||||
if data == nil {
|
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
|
||||||
errLog.Print(ErrBusyBuffer)
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the scramble
|
|
||||||
copy(data[4:], scrambleBuff)
|
|
||||||
|
|
||||||
return mc.writePacket(data)
|
return mc.writePacket(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
||||||
|
|
||||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add command byte
|
// Add command byte
|
||||||
|
@ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
||||||
pktLen := 1 + len(arg)
|
pktLen := 1 + len(arg)
|
||||||
data := mc.buf.takeBuffer(pktLen + 4)
|
data := mc.buf.takeBuffer(pktLen + 4)
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add command byte
|
// Add command byte
|
||||||
|
@ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
||||||
|
|
||||||
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add command byte
|
// Add command byte
|
||||||
|
@ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
||||||
* Result Packets *
|
* Result Packets *
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
// Returns error if Packet is not an 'Result OK'-Packet
|
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
||||||
func (mc *mysqlConn) readResultOK() ([]byte, error) {
|
|
||||||
data, err := mc.readPacket()
|
data, err := mc.readPacket()
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
// packet indicator
|
// packet indicator
|
||||||
switch data[0] {
|
switch data[0] {
|
||||||
|
|
||||||
case iOK:
|
case iOK:
|
||||||
return nil, mc.handleOkPacket(data)
|
return nil, "", mc.handleOkPacket(data)
|
||||||
|
|
||||||
|
case iAuthMoreData:
|
||||||
|
return data[1:], "", err
|
||||||
|
|
||||||
case iEOF:
|
case iEOF:
|
||||||
if len(data) > 1 {
|
if len(data) < 1 {
|
||||||
|
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||||
|
return nil, "mysql_old_password", nil
|
||||||
|
}
|
||||||
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
||||||
|
if pluginEndIndex < 0 {
|
||||||
|
return nil, "", ErrMalformPkt
|
||||||
|
}
|
||||||
plugin := string(data[1:pluginEndIndex])
|
plugin := string(data[1:pluginEndIndex])
|
||||||
cipher := data[pluginEndIndex+1 : len(data)-1]
|
authData := data[pluginEndIndex+1:]
|
||||||
|
return authData, plugin, nil
|
||||||
if plugin == "mysql_old_password" {
|
|
||||||
// using old_passwords
|
|
||||||
return cipher, ErrOldPassword
|
|
||||||
} else if plugin == "mysql_clear_password" {
|
|
||||||
// using clear text password
|
|
||||||
return cipher, ErrCleartextPassword
|
|
||||||
} else if plugin == "mysql_native_password" {
|
|
||||||
// using mysql default authentication method
|
|
||||||
return cipher, ErrNativePassword
|
|
||||||
} else {
|
|
||||||
return cipher, ErrUnknownPlugin
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, ErrOldPassword
|
|
||||||
}
|
|
||||||
|
|
||||||
default: // Error otherwise
|
default: // Error otherwise
|
||||||
return nil, mc.handleErrorPacket(data)
|
return nil, "", mc.handleErrorPacket(data)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns error if Packet is not an 'Result OK'-Packet
|
||||||
|
func (mc *mysqlConn) readResultOK() error {
|
||||||
|
data, err := mc.readPacket()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil, err
|
|
||||||
|
if data[0] == iOK {
|
||||||
|
return mc.handleOkPacket(data)
|
||||||
|
}
|
||||||
|
return mc.handleErrorPacket(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Result Set Header Packet
|
// Result Set Header Packet
|
||||||
|
@ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
||||||
// Error Number [16 bit uint]
|
// Error Number [16 bit uint]
|
||||||
errno := binary.LittleEndian.Uint16(data[1:3])
|
errno := binary.LittleEndian.Uint16(data[1:3])
|
||||||
|
|
||||||
|
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
|
||||||
|
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
|
||||||
|
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
|
||||||
|
// Oops; we are connected to a read-only connection, and won't be able
|
||||||
|
// to issue any write statements. Since RejectReadOnly is configured,
|
||||||
|
// we throw away this connection hoping this one would have write
|
||||||
|
// permission. This is specifically for a possible race condition
|
||||||
|
// during failover (e.g. on AWS Aurora). See README.md for more.
|
||||||
|
//
|
||||||
|
// We explicitly close the connection before returning
|
||||||
|
// driver.ErrBadConn to ensure that `database/sql` purges this
|
||||||
|
// connection and initiates a new one for next statement next time.
|
||||||
|
mc.Close()
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
pos := 3
|
pos := 3
|
||||||
|
|
||||||
// SQL State [optional: # + 5bytes string]
|
// SQL State [optional: # + 5bytes string]
|
||||||
|
@ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
|
||||||
|
|
||||||
// server_status [2 bytes]
|
// server_status [2 bytes]
|
||||||
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
||||||
if err := mc.discardResults(); err != nil {
|
if mc.status&statusMoreResultsExists != 0 {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// warning count [2 bytes]
|
|
||||||
if !mc.strict {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pos := 1 + n + m + 2
|
// warning count [2 bytes]
|
||||||
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
|
|
||||||
return mc.getWarnings()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
pos += n
|
||||||
|
|
||||||
// Filler [uint8]
|
// Filler [uint8]
|
||||||
|
pos++
|
||||||
|
|
||||||
// Charset [charset, collation uint8]
|
// Charset [charset, collation uint8]
|
||||||
|
columns[i].charSet = data[pos]
|
||||||
|
pos += 2
|
||||||
|
|
||||||
// Length [uint32]
|
// Length [uint32]
|
||||||
pos += n + 1 + 2 + 4
|
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
|
||||||
|
pos += 4
|
||||||
|
|
||||||
// Field type [uint8]
|
// Field type [uint8]
|
||||||
columns[i].fieldType = data[pos]
|
columns[i].fieldType = fieldType(data[pos])
|
||||||
pos++
|
pos++
|
||||||
|
|
||||||
// Flags [uint16]
|
// Flags [uint16]
|
||||||
|
@ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||||
func (rows *textRows) readRow(dest []driver.Value) error {
|
func (rows *textRows) readRow(dest []driver.Value) error {
|
||||||
mc := rows.mc
|
mc := rows.mc
|
||||||
|
|
||||||
|
if rows.rs.done {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
data, err := mc.readPacket()
|
data, err := mc.readPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
||||||
if data[0] == iEOF && len(data) == 5 {
|
if data[0] == iEOF && len(data) == 5 {
|
||||||
// server_status [2 bytes]
|
// server_status [2 bytes]
|
||||||
rows.mc.status = readStatus(data[3:])
|
rows.mc.status = readStatus(data[3:])
|
||||||
if err := rows.mc.discardResults(); err != nil {
|
rows.rs.done = true
|
||||||
return err
|
if !rows.HasNextResultSet() {
|
||||||
}
|
|
||||||
rows.mc = nil
|
rows.mc = nil
|
||||||
|
}
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
if data[0] == iERR {
|
if data[0] == iERR {
|
||||||
|
@ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
||||||
if !mc.parseTime {
|
if !mc.parseTime {
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
switch rows.columns[i].fieldType {
|
switch rows.rs.columns[i].fieldType {
|
||||||
case fieldTypeTimestamp, fieldTypeDateTime,
|
case fieldTypeTimestamp, fieldTypeDateTime,
|
||||||
fieldTypeDate, fieldTypeNewDate:
|
fieldTypeDate, fieldTypeNewDate:
|
||||||
dest[i], err = parseDateTime(
|
dest[i], err = parseDateTime(
|
||||||
|
@ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
||||||
// Reserved [8 bit]
|
// Reserved [8 bit]
|
||||||
|
|
||||||
// Warning count [16 bit uint]
|
// Warning count [16 bit uint]
|
||||||
if !stmt.mc.strict {
|
|
||||||
return columnCount, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for warnings count > 0, only available in MySQL > 4.1
|
|
||||||
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
|
|
||||||
return columnCount, stmt.mc.getWarnings()
|
|
||||||
}
|
|
||||||
return columnCount, nil
|
return columnCount, nil
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
||||||
// 2 bytes paramID
|
// 2 bytes paramID
|
||||||
const dataOffset = 1 + 4 + 2
|
const dataOffset = 1 + 4 + 2
|
||||||
|
|
||||||
// Can not use the write buffer since
|
// Cannot use the write buffer since
|
||||||
// a) the buffer is too small
|
// a) the buffer is too small
|
||||||
// b) it is in use
|
// b) it is in use
|
||||||
data := make([]byte, 4+1+4+2+len(arg))
|
data := make([]byte, 4+1+4+2+len(arg))
|
||||||
|
@ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
const minPktLen = 4 + 1 + 4 + 1 + 4
|
const minPktLen = 4 + 1 + 4 + 1 + 4
|
||||||
mc := stmt.mc
|
mc := stmt.mc
|
||||||
|
|
||||||
|
// Determine threshould dynamically to avoid packet size shortage.
|
||||||
|
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||||
|
if longDataSize < 64 {
|
||||||
|
longDataSize = 64
|
||||||
|
}
|
||||||
|
|
||||||
// Reset packet-sequence
|
// Reset packet-sequence
|
||||||
mc.sequence = 0
|
mc.sequence = 0
|
||||||
|
|
||||||
|
@ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
data = mc.buf.takeCompleteBuffer()
|
data = mc.buf.takeCompleteBuffer()
|
||||||
}
|
}
|
||||||
if data == nil {
|
if data == nil {
|
||||||
// can not take the buffer. Something must be wrong with the connection
|
// cannot take the buffer. Something must be wrong with the connection
|
||||||
errLog.Print(ErrBusyBuffer)
|
errLog.Print(ErrBusyBuffer)
|
||||||
return driver.ErrBadConn
|
return errBadConnNoWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
// command [1 byte]
|
// command [1 byte]
|
||||||
|
@ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
// build NULL-bitmap
|
// build NULL-bitmap
|
||||||
if arg == nil {
|
if arg == nil {
|
||||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||||
paramTypes[i+i] = fieldTypeNULL
|
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
// cache types and values
|
// cache types and values
|
||||||
switch v := arg.(type) {
|
switch v := arg.(type) {
|
||||||
case int64:
|
case int64:
|
||||||
paramTypes[i+i] = fieldTypeLongLong
|
paramTypes[i+i] = byte(fieldTypeLongLong)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||||
|
@ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
case float64:
|
case float64:
|
||||||
paramTypes[i+i] = fieldTypeDouble
|
paramTypes[i+i] = byte(fieldTypeDouble)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||||
|
@ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
case bool:
|
case bool:
|
||||||
paramTypes[i+i] = fieldTypeTiny
|
paramTypes[i+i] = byte(fieldTypeTiny)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
if v {
|
if v {
|
||||||
|
@ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
case []byte:
|
case []byte:
|
||||||
// Common case (non-nil value) first
|
// Common case (non-nil value) first
|
||||||
if v != nil {
|
if v != nil {
|
||||||
paramTypes[i+i] = fieldTypeString
|
paramTypes[i+i] = byte(fieldTypeString)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
if len(v) < longDataSize {
|
||||||
paramValues = appendLengthEncodedInteger(paramValues,
|
paramValues = appendLengthEncodedInteger(paramValues,
|
||||||
uint64(len(v)),
|
uint64(len(v)),
|
||||||
)
|
)
|
||||||
|
@ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
|
|
||||||
// Handle []byte(nil) as a NULL value
|
// Handle []byte(nil) as a NULL value
|
||||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||||
paramTypes[i+i] = fieldTypeNULL
|
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
case string:
|
case string:
|
||||||
paramTypes[i+i] = fieldTypeString
|
paramTypes[i+i] = byte(fieldTypeString)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
if len(v) < longDataSize {
|
||||||
paramValues = appendLengthEncodedInteger(paramValues,
|
paramValues = appendLengthEncodedInteger(paramValues,
|
||||||
uint64(len(v)),
|
uint64(len(v)),
|
||||||
)
|
)
|
||||||
|
@ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
case time.Time:
|
case time.Time:
|
||||||
paramTypes[i+i] = fieldTypeString
|
paramTypes[i+i] = byte(fieldTypeString)
|
||||||
paramTypes[i+i+1] = 0x00
|
paramTypes[i+i+1] = 0x00
|
||||||
|
|
||||||
var val []byte
|
var a [64]byte
|
||||||
|
var b = a[:0]
|
||||||
|
|
||||||
if v.IsZero() {
|
if v.IsZero() {
|
||||||
val = []byte("0000-00-00")
|
b = append(b, "0000-00-00"...)
|
||||||
} else {
|
} else {
|
||||||
val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
|
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramValues = appendLengthEncodedInteger(paramValues,
|
paramValues = appendLengthEncodedInteger(paramValues,
|
||||||
uint64(len(val)),
|
uint64(len(b)),
|
||||||
)
|
)
|
||||||
paramValues = append(paramValues, val...)
|
paramValues = append(paramValues, b...)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("can not convert type: %T", arg)
|
return fmt.Errorf("cannot convert type: %T", arg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error {
|
||||||
if err := mc.readUntilEOF(); err != nil {
|
if err := mc.readUntilEOF(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
mc.status &^= statusMoreResultsExists
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
// EOF Packet
|
// EOF Packet
|
||||||
if data[0] == iEOF && len(data) == 5 {
|
if data[0] == iEOF && len(data) == 5 {
|
||||||
rows.mc.status = readStatus(data[3:])
|
rows.mc.status = readStatus(data[3:])
|
||||||
if err := rows.mc.discardResults(); err != nil {
|
rows.rs.done = true
|
||||||
return err
|
if !rows.HasNextResultSet() {
|
||||||
}
|
|
||||||
rows.mc = nil
|
rows.mc = nil
|
||||||
|
}
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
mc := rows.mc
|
||||||
rows.mc = nil
|
rows.mc = nil
|
||||||
|
|
||||||
// Error otherwise
|
// Error otherwise
|
||||||
return rows.mc.handleErrorPacket(data)
|
return mc.handleErrorPacket(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
||||||
|
@ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to byte-coded string
|
// Convert to byte-coded string
|
||||||
switch rows.columns[i].fieldType {
|
switch rows.rs.columns[i].fieldType {
|
||||||
case fieldTypeNULL:
|
case fieldTypeNULL:
|
||||||
dest[i] = nil
|
dest[i] = nil
|
||||||
continue
|
continue
|
||||||
|
|
||||||
// Numeric Types
|
// Numeric Types
|
||||||
case fieldTypeTiny:
|
case fieldTypeTiny:
|
||||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||||
dest[i] = int64(data[pos])
|
dest[i] = int64(data[pos])
|
||||||
} else {
|
} else {
|
||||||
dest[i] = int64(int8(data[pos]))
|
dest[i] = int64(int8(data[pos]))
|
||||||
|
@ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case fieldTypeShort, fieldTypeYear:
|
case fieldTypeShort, fieldTypeYear:
|
||||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||||
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||||
} else {
|
} else {
|
||||||
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
||||||
|
@ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case fieldTypeInt24, fieldTypeLong:
|
case fieldTypeInt24, fieldTypeLong:
|
||||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||||
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||||
} else {
|
} else {
|
||||||
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
||||||
|
@ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case fieldTypeLongLong:
|
case fieldTypeLongLong:
|
||||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||||
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
||||||
if val > math.MaxInt64 {
|
if val > math.MaxInt64 {
|
||||||
dest[i] = uint64ToString(val)
|
dest[i] = uint64ToString(val)
|
||||||
|
@ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case fieldTypeFloat:
|
case fieldTypeFloat:
|
||||||
dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||||
pos += 4
|
pos += 4
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
case isNull:
|
case isNull:
|
||||||
dest[i] = nil
|
dest[i] = nil
|
||||||
continue
|
continue
|
||||||
case rows.columns[i].fieldType == fieldTypeTime:
|
case rows.rs.columns[i].fieldType == fieldTypeTime:
|
||||||
// database/sql does not support an equivalent to TIME, return a string
|
// database/sql does not support an equivalent to TIME, return a string
|
||||||
var dstlen uint8
|
var dstlen uint8
|
||||||
switch decimals := rows.columns[i].decimals; decimals {
|
switch decimals := rows.rs.columns[i].decimals; decimals {
|
||||||
case 0x00, 0x1f:
|
case 0x00, 0x1f:
|
||||||
dstlen = 8
|
dstlen = 8
|
||||||
case 1, 2, 3, 4, 5, 6:
|
case 1, 2, 3, 4, 5, 6:
|
||||||
|
@ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"protocol error, illegal decimals value %d",
|
"protocol error, illegal decimals value %d",
|
||||||
rows.columns[i].decimals,
|
rows.rs.columns[i].decimals,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
|
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
|
||||||
|
@ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
||||||
default:
|
default:
|
||||||
var dstlen uint8
|
var dstlen uint8
|
||||||
if rows.columns[i].fieldType == fieldTypeDate {
|
if rows.rs.columns[i].fieldType == fieldTypeDate {
|
||||||
dstlen = 10
|
dstlen = 10
|
||||||
} else {
|
} else {
|
||||||
switch decimals := rows.columns[i].decimals; decimals {
|
switch decimals := rows.rs.columns[i].decimals; decimals {
|
||||||
case 0x00, 0x1f:
|
case 0x00, 0x1f:
|
||||||
dstlen = 19
|
dstlen = 19
|
||||||
case 1, 2, 3, 4, 5, 6:
|
case 1, 2, 3, 4, 5, 6:
|
||||||
|
@ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"protocol error, illegal decimals value %d",
|
"protocol error, illegal decimals value %d",
|
||||||
rows.columns[i].decimals,
|
rows.rs.columns[i].decimals,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||||
|
|
||||||
// Please report if this happens!
|
// Please report if this happens!
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
|
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,19 +11,20 @@ package mysql
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysqlField struct {
|
type resultSet struct {
|
||||||
tableName string
|
columns []mysqlField
|
||||||
name string
|
columnNames []string
|
||||||
flags fieldFlag
|
done bool
|
||||||
fieldType byte
|
|
||||||
decimals byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mysqlRows struct {
|
type mysqlRows struct {
|
||||||
mc *mysqlConn
|
mc *mysqlConn
|
||||||
columns []mysqlField
|
rs resultSet
|
||||||
|
finish func()
|
||||||
}
|
}
|
||||||
|
|
||||||
type binaryRows struct {
|
type binaryRows struct {
|
||||||
|
@ -34,37 +35,86 @@ type textRows struct {
|
||||||
mysqlRows
|
mysqlRows
|
||||||
}
|
}
|
||||||
|
|
||||||
type emptyRows struct{}
|
|
||||||
|
|
||||||
func (rows *mysqlRows) Columns() []string {
|
func (rows *mysqlRows) Columns() []string {
|
||||||
columns := make([]string, len(rows.columns))
|
if rows.rs.columnNames != nil {
|
||||||
|
return rows.rs.columnNames
|
||||||
|
}
|
||||||
|
|
||||||
|
columns := make([]string, len(rows.rs.columns))
|
||||||
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
|
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
|
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
|
||||||
columns[i] = tableName + "." + rows.columns[i].name
|
columns[i] = tableName + "." + rows.rs.columns[i].name
|
||||||
} else {
|
} else {
|
||||||
columns[i] = rows.columns[i].name
|
columns[i] = rows.rs.columns[i].name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
columns[i] = rows.columns[i].name
|
columns[i] = rows.rs.columns[i].name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rows.rs.columnNames = columns
|
||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *mysqlRows) Close() error {
|
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
|
||||||
|
return rows.rs.columns[i].typeDatabaseName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
|
||||||
|
// return int64(rows.rs.columns[i].length), true
|
||||||
|
// }
|
||||||
|
|
||||||
|
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
|
||||||
|
return rows.rs.columns[i].flags&flagNotNULL == 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
|
||||||
|
column := rows.rs.columns[i]
|
||||||
|
decimals := int64(column.decimals)
|
||||||
|
|
||||||
|
switch column.fieldType {
|
||||||
|
case fieldTypeDecimal, fieldTypeNewDecimal:
|
||||||
|
if decimals > 0 {
|
||||||
|
return int64(column.length) - 2, decimals, true
|
||||||
|
}
|
||||||
|
return int64(column.length) - 1, decimals, true
|
||||||
|
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
|
||||||
|
return decimals, decimals, true
|
||||||
|
case fieldTypeFloat, fieldTypeDouble:
|
||||||
|
if decimals == 0x1f {
|
||||||
|
return math.MaxInt64, math.MaxInt64, true
|
||||||
|
}
|
||||||
|
return math.MaxInt64, decimals, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
|
||||||
|
return rows.rs.columns[i].scanType()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) Close() (err error) {
|
||||||
|
if f := rows.finish; f != nil {
|
||||||
|
f()
|
||||||
|
rows.finish = nil
|
||||||
|
}
|
||||||
|
|
||||||
mc := rows.mc
|
mc := rows.mc
|
||||||
if mc == nil {
|
if mc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if mc.netConn == nil {
|
if err := mc.error(); err != nil {
|
||||||
return ErrInvalidConn
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove unread packets from stream
|
// Remove unread packets from stream
|
||||||
err := mc.readUntilEOF()
|
if !rows.rs.done {
|
||||||
|
err = mc.readUntilEOF()
|
||||||
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if err = mc.discardResults(); err != nil {
|
if err = mc.discardResults(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -75,10 +125,66 @@ func (rows *mysqlRows) Close() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) HasNextResultSet() (b bool) {
|
||||||
|
if rows.mc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rows.mc.status&statusMoreResultsExists != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) nextResultSet() (int, error) {
|
||||||
|
if rows.mc == nil {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if err := rows.mc.error(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove unread packets from stream
|
||||||
|
if !rows.rs.done {
|
||||||
|
if err := rows.mc.readUntilEOF(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
rows.rs.done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.HasNextResultSet() {
|
||||||
|
rows.mc = nil
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
rows.rs = resultSet{}
|
||||||
|
return rows.mc.readResultSetHeaderPacket()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
|
||||||
|
for {
|
||||||
|
resLen, err := rows.nextResultSet()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resLen > 0 {
|
||||||
|
return resLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.rs.done = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rows *binaryRows) NextResultSet() error {
|
||||||
|
resLen, err := rows.nextNotEmptyResultSet()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (rows *binaryRows) Next(dest []driver.Value) error {
|
func (rows *binaryRows) Next(dest []driver.Value) error {
|
||||||
if mc := rows.mc; mc != nil {
|
if mc := rows.mc; mc != nil {
|
||||||
if mc.netConn == nil {
|
if err := mc.error(); err != nil {
|
||||||
return ErrInvalidConn
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch next row from stream
|
// Fetch next row from stream
|
||||||
|
@ -87,10 +193,20 @@ func (rows *binaryRows) Next(dest []driver.Value) error {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rows *textRows) NextResultSet() (err error) {
|
||||||
|
resLen, err := rows.nextNotEmptyResultSet()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (rows *textRows) Next(dest []driver.Value) error {
|
func (rows *textRows) Next(dest []driver.Value) error {
|
||||||
if mc := rows.mc; mc != nil {
|
if mc := rows.mc; mc != nil {
|
||||||
if mc.netConn == nil {
|
if err := mc.error(); err != nil {
|
||||||
return ErrInvalidConn
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch next row from stream
|
// Fetch next row from stream
|
||||||
|
@ -98,15 +214,3 @@ func (rows *textRows) Next(dest []driver.Value) error {
|
||||||
}
|
}
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows emptyRows) Columns() []string {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rows emptyRows) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rows emptyRows) Next(dest []driver.Value) error {
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ package mysql
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
@ -19,12 +20,14 @@ type mysqlStmt struct {
|
||||||
mc *mysqlConn
|
mc *mysqlConn
|
||||||
id uint32
|
id uint32
|
||||||
paramCount int
|
paramCount int
|
||||||
columns []mysqlField // cached from the first query
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *mysqlStmt) Close() error {
|
func (stmt *mysqlStmt) Close() error {
|
||||||
if stmt.mc == nil || stmt.mc.netConn == nil {
|
if stmt.mc == nil || stmt.mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
// driver.Stmt.Close can be called more than once, thus this function
|
||||||
|
// has to be idempotent.
|
||||||
|
// See also Issue #450 and golang/go#16019.
|
||||||
|
//errLog.Print(ErrInvalidConn)
|
||||||
return driver.ErrBadConn
|
return driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
if stmt.mc.netConn == nil {
|
if stmt.mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
// Send command
|
// Send command
|
||||||
err := stmt.writeExecutePacket(args)
|
err := stmt.writeExecutePacket(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, stmt.mc.markBadConn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mc := stmt.mc
|
mc := stmt.mc
|
||||||
|
@ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
|
|
||||||
// Read Result
|
// Read Result
|
||||||
resLen, err := mc.readResultSetHeaderPacket()
|
resLen, err := mc.readResultSetHeaderPacket()
|
||||||
if err == nil {
|
|
||||||
if resLen > 0 {
|
|
||||||
// Columns
|
|
||||||
err = mc.readUntilEOF()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rows
|
if resLen > 0 {
|
||||||
err = mc.readUntilEOF()
|
// Columns
|
||||||
|
if err = mc.readUntilEOF(); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
if err == nil {
|
|
||||||
|
// Rows
|
||||||
|
if err := mc.readUntilEOF(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mc.discardResults(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &mysqlResult{
|
return &mysqlResult{
|
||||||
affectedRows: int64(mc.affectedRows),
|
affectedRows: int64(mc.affectedRows),
|
||||||
insertId: int64(mc.insertId),
|
insertId: int64(mc.insertId),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
if stmt.mc.netConn == nil {
|
return stmt.query(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||||
|
if stmt.mc.closed.IsSet() {
|
||||||
errLog.Print(ErrInvalidConn)
|
errLog.Print(ErrInvalidConn)
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
// Send command
|
// Send command
|
||||||
err := stmt.writeExecutePacket(args)
|
err := stmt.writeExecutePacket(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, stmt.mc.markBadConn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mc := stmt.mc
|
mc := stmt.mc
|
||||||
|
@ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
|
||||||
if resLen > 0 {
|
if resLen > 0 {
|
||||||
rows.mc = mc
|
rows.mc = mc
|
||||||
// Columns
|
rows.rs.columns, err = mc.readColumns(resLen)
|
||||||
// If not cached, read them and cache them
|
|
||||||
if stmt.columns == nil {
|
|
||||||
rows.columns, err = mc.readColumns(resLen)
|
|
||||||
stmt.columns = rows.columns
|
|
||||||
} else {
|
} else {
|
||||||
rows.columns = stmt.columns
|
rows.rs.done = true
|
||||||
err = mc.readUntilEOF()
|
|
||||||
|
switch err := rows.NextResultSet(); err {
|
||||||
|
case nil, io.EOF:
|
||||||
|
return rows, nil
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
|
||||||
type converter struct{}
|
type converter struct{}
|
||||||
|
|
||||||
|
// ConvertValue mirrors the reference/default converter in database/sql/driver
|
||||||
|
// with _one_ exception. We support uint64 with their high bit and the default
|
||||||
|
// implementation does not. This function should be kept in sync with
|
||||||
|
// database/sql/driver defaultConverter.ConvertValue() except for that
|
||||||
|
// deliberate difference.
|
||||||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
||||||
if driver.IsValue(v) {
|
if driver.IsValue(v) {
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if vr, ok := v.(driver.Valuer); ok {
|
||||||
|
sv, err := callValuerValue(vr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !driver.IsValue(sv) {
|
||||||
|
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||||
|
}
|
||||||
|
return sv, nil
|
||||||
|
}
|
||||||
|
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
switch rv.Kind() {
|
switch rv.Kind() {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
// indirect pointers
|
// indirect pointers
|
||||||
if rv.IsNil() {
|
if rv.IsNil() {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
} else {
|
||||||
return c.ConvertValue(rv.Elem().Interface())
|
return c.ConvertValue(rv.Elem().Interface())
|
||||||
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
return rv.Int(), nil
|
return rv.Int(), nil
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||||
|
@ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
||||||
return int64(u64), nil
|
return int64(u64), nil
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return rv.Float(), nil
|
return rv.Float(), nil
|
||||||
|
case reflect.Bool:
|
||||||
|
return rv.Bool(), nil
|
||||||
|
case reflect.Slice:
|
||||||
|
ek := rv.Type().Elem().Kind()
|
||||||
|
if ek == reflect.Uint8 {
|
||||||
|
return rv.Bytes(), nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
|
||||||
|
case reflect.String:
|
||||||
|
return rv.String(), nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||||
|
|
||||||
|
// callValuerValue returns vr.Value(), with one exception:
|
||||||
|
// If vr.Value is an auto-generated method on a pointer type and the
|
||||||
|
// pointer is nil, it would panic at runtime in the panicwrap
|
||||||
|
// method. Treat it like nil instead.
|
||||||
|
//
|
||||||
|
// This is so people can implement driver.Value on value types and
|
||||||
|
// still use nil pointers to those types to mean nil/NULL, just like
|
||||||
|
// string/*string.
|
||||||
|
//
|
||||||
|
// This is an exact copy of the same-named unexported function from the
|
||||||
|
// database/sql package.
|
||||||
|
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||||
|
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||||
|
rv.IsNil() &&
|
||||||
|
rv.Type().Elem().Implements(valuerReflectType) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return vr.Value()
|
||||||
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ type mysqlTx struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *mysqlTx) Commit() (err error) {
|
func (tx *mysqlTx) Commit() (err error) {
|
||||||
if tx.mc == nil || tx.mc.netConn == nil {
|
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||||
return ErrInvalidConn
|
return ErrInvalidConn
|
||||||
}
|
}
|
||||||
err = tx.mc.exec("COMMIT")
|
err = tx.mc.exec("COMMIT")
|
||||||
|
@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *mysqlTx) Rollback() (err error) {
|
func (tx *mysqlTx) Rollback() (err error) {
|
||||||
if tx.mc == nil || tx.mc.netConn == nil {
|
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||||
return ErrInvalidConn
|
return ErrInvalidConn
|
||||||
}
|
}
|
||||||
err = tx.mc.exec("ROLLBACK")
|
err = tx.mc.exec("ROLLBACK")
|
||||||
|
|
|
@ -9,23 +9,29 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Registry for custom tls.Configs
|
||||||
var (
|
var (
|
||||||
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
|
tlsConfigLock sync.RWMutex
|
||||||
|
tlsConfigRegistry map[string]*tls.Config
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
||||||
// Use the key as a value in the DSN where tls=value.
|
// Use the key as a value in the DSN where tls=value.
|
||||||
//
|
//
|
||||||
|
// Note: The provided tls.Config is exclusively owned by the driver after
|
||||||
|
// registering it.
|
||||||
|
//
|
||||||
// rootCertPool := x509.NewCertPool()
|
// rootCertPool := x509.NewCertPool()
|
||||||
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
|
@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error {
|
||||||
return fmt.Errorf("key '%s' is reserved", key)
|
return fmt.Errorf("key '%s' is reserved", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tlsConfigRegister == nil {
|
tlsConfigLock.Lock()
|
||||||
tlsConfigRegister = make(map[string]*tls.Config)
|
if tlsConfigRegistry == nil {
|
||||||
|
tlsConfigRegistry = make(map[string]*tls.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfigRegister[key] = config
|
tlsConfigRegistry[key] = config
|
||||||
|
tlsConfigLock.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeregisterTLSConfig removes the tls.Config associated with key.
|
// DeregisterTLSConfig removes the tls.Config associated with key.
|
||||||
func DeregisterTLSConfig(key string) {
|
func DeregisterTLSConfig(key string) {
|
||||||
if tlsConfigRegister != nil {
|
tlsConfigLock.Lock()
|
||||||
delete(tlsConfigRegister, key)
|
if tlsConfigRegistry != nil {
|
||||||
|
delete(tlsConfigRegistry, key)
|
||||||
}
|
}
|
||||||
|
tlsConfigLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTLSConfigClone(key string) (config *tls.Config) {
|
||||||
|
tlsConfigLock.RLock()
|
||||||
|
if v, ok := tlsConfigRegistry[key]; ok {
|
||||||
|
config = cloneTLSConfig(v)
|
||||||
|
}
|
||||||
|
tlsConfigLock.RUnlock()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the bool value of the input.
|
// Returns the bool value of the input.
|
||||||
|
@ -80,119 +99,6 @@ func readBool(input string) (value bool, valid bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/******************************************************************************
|
|
||||||
* Authentication *
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
// Encrypt password using 4.1+ method
|
|
||||||
func scramblePassword(scramble, password []byte) []byte {
|
|
||||||
if len(password) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// stage1Hash = SHA1(password)
|
|
||||||
crypt := sha1.New()
|
|
||||||
crypt.Write(password)
|
|
||||||
stage1 := crypt.Sum(nil)
|
|
||||||
|
|
||||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
|
||||||
// inner Hash
|
|
||||||
crypt.Reset()
|
|
||||||
crypt.Write(stage1)
|
|
||||||
hash := crypt.Sum(nil)
|
|
||||||
|
|
||||||
// outer Hash
|
|
||||||
crypt.Reset()
|
|
||||||
crypt.Write(scramble)
|
|
||||||
crypt.Write(hash)
|
|
||||||
scramble = crypt.Sum(nil)
|
|
||||||
|
|
||||||
// token = scrambleHash XOR stage1Hash
|
|
||||||
for i := range scramble {
|
|
||||||
scramble[i] ^= stage1[i]
|
|
||||||
}
|
|
||||||
return scramble
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt password using pre 4.1 (old password) method
|
|
||||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
|
|
||||||
type myRnd struct {
|
|
||||||
seed1, seed2 uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
const myRndMaxVal = 0x3FFFFFFF
|
|
||||||
|
|
||||||
// Pseudo random number generator
|
|
||||||
func newMyRnd(seed1, seed2 uint32) *myRnd {
|
|
||||||
return &myRnd{
|
|
||||||
seed1: seed1 % myRndMaxVal,
|
|
||||||
seed2: seed2 % myRndMaxVal,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tested to be equivalent to MariaDB's floating point variant
|
|
||||||
// http://play.golang.org/p/QHvhd4qved
|
|
||||||
// http://play.golang.org/p/RG0q4ElWDx
|
|
||||||
func (r *myRnd) NextByte() byte {
|
|
||||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
|
|
||||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
|
|
||||||
|
|
||||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate binary hash from byte string using insecure pre 4.1 method
|
|
||||||
func pwHash(password []byte) (result [2]uint32) {
|
|
||||||
var add uint32 = 7
|
|
||||||
var tmp uint32
|
|
||||||
|
|
||||||
result[0] = 1345345333
|
|
||||||
result[1] = 0x12345671
|
|
||||||
|
|
||||||
for _, c := range password {
|
|
||||||
// skip spaces and tabs in password
|
|
||||||
if c == ' ' || c == '\t' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp = uint32(c)
|
|
||||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
|
|
||||||
result[1] += (result[1] << 8) ^ result[0]
|
|
||||||
add += tmp
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove sign bit (1<<31)-1)
|
|
||||||
result[0] &= 0x7FFFFFFF
|
|
||||||
result[1] &= 0x7FFFFFFF
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt password using insecure pre 4.1 method
|
|
||||||
func scrambleOldPassword(scramble, password []byte) []byte {
|
|
||||||
if len(password) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
scramble = scramble[:8]
|
|
||||||
|
|
||||||
hashPw := pwHash(password)
|
|
||||||
hashSc := pwHash(scramble)
|
|
||||||
|
|
||||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
|
||||||
|
|
||||||
var out [8]byte
|
|
||||||
for i := range out {
|
|
||||||
out[i] = r.NextByte() + 64
|
|
||||||
}
|
|
||||||
|
|
||||||
mask := r.NextByte()
|
|
||||||
for i := range out {
|
|
||||||
out[i] ^= mask
|
|
||||||
}
|
|
||||||
|
|
||||||
return out[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
/******************************************************************************
|
/******************************************************************************
|
||||||
* Time related utils *
|
* Time related utils *
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
|
||||||
|
|
||||||
// Check data length
|
// Check data length
|
||||||
if len(b) >= n {
|
if len(b) >= n {
|
||||||
return b[n-int(num) : n], false, n, nil
|
return b[n-int(num) : n : n], false, n, nil
|
||||||
}
|
}
|
||||||
return nil, false, n, io.EOF
|
return nil, false, n, io.EOF
|
||||||
}
|
}
|
||||||
|
@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, true, 1
|
return 0, true, 1
|
||||||
}
|
}
|
||||||
switch b[0] {
|
|
||||||
|
|
||||||
|
switch b[0] {
|
||||||
// 251: NULL
|
// 251: NULL
|
||||||
case 0xfb:
|
case 0xfb:
|
||||||
return 0, true, 1
|
return 0, true, 1
|
||||||
|
@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte {
|
||||||
|
|
||||||
return buf[:pos]
|
return buf[:pos]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************
|
||||||
|
* Sync utils *
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
// noCopy may be embedded into structs which must not be copied
|
||||||
|
// after the first use.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
|
||||||
|
// for details.
|
||||||
|
type noCopy struct{}
|
||||||
|
|
||||||
|
// Lock is a no-op used by -copylocks checker from `go vet`.
|
||||||
|
func (*noCopy) Lock() {}
|
||||||
|
|
||||||
|
// atomicBool is a wrapper around uint32 for usage as a boolean value with
|
||||||
|
// atomic access.
|
||||||
|
type atomicBool struct {
|
||||||
|
_noCopy noCopy
|
||||||
|
value uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSet returns wether the current boolean value is true
|
||||||
|
func (ab *atomicBool) IsSet() bool {
|
||||||
|
return atomic.LoadUint32(&ab.value) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the value of the bool regardless of the previous value
|
||||||
|
func (ab *atomicBool) Set(value bool) {
|
||||||
|
if value {
|
||||||
|
atomic.StoreUint32(&ab.value, 1)
|
||||||
|
} else {
|
||||||
|
atomic.StoreUint32(&ab.value, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrySet sets the value of the bool and returns wether the value changed
|
||||||
|
func (ab *atomicBool) TrySet(value bool) bool {
|
||||||
|
if value {
|
||||||
|
return atomic.SwapUint32(&ab.value, 1) == 0
|
||||||
|
}
|
||||||
|
return atomic.SwapUint32(&ab.value, 0) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// atomicError is a wrapper for atomically accessed error values
|
||||||
|
type atomicError struct {
|
||||||
|
_noCopy noCopy
|
||||||
|
value atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the error value regardless of the previous value.
|
||||||
|
// The value must not be nil
|
||||||
|
func (ae *atomicError) Set(value error) {
|
||||||
|
ae.value.Store(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the current error value
|
||||||
|
func (ae *atomicError) Value() error {
|
||||||
|
if v := ae.value.Load(); v != nil {
|
||||||
|
// this will panic if the value doesn't implement the error interface
|
||||||
|
return v.(error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||||
|
//
|
||||||
|
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
// +build go1.7
|
||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
func cloneTLSConfig(c *tls.Config) *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: c.Rand,
|
||||||
|
Time: c.Time,
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
NameToCertificate: c.NameToCertificate,
|
||||||
|
GetCertificate: c.GetCertificate,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
ClientAuth: c.ClientAuth,
|
||||||
|
ClientCAs: c.ClientCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
ClientSessionCache: c.ClientSessionCache,
|
||||||
|
MinVersion: c.MinVersion,
|
||||||
|
MaxVersion: c.MaxVersion,
|
||||||
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: c.Renegotiation,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||||
|
//
|
||||||
|
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cloneTLSConfig(c *tls.Config) *tls.Config {
|
||||||
|
return c.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
|
||||||
|
dargs := make([]driver.Value, len(named))
|
||||||
|
for n, param := range named {
|
||||||
|
if len(param.Name) > 0 {
|
||||||
|
// TODO: support the use of Named Parameters #561
|
||||||
|
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
|
||||||
|
}
|
||||||
|
dargs[n] = param.Value
|
||||||
|
}
|
||||||
|
return dargs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
|
||||||
|
switch sql.IsolationLevel(level) {
|
||||||
|
case sql.LevelRepeatableRead:
|
||||||
|
return "REPEATABLE READ", nil
|
||||||
|
case sql.LevelReadCommitted:
|
||||||
|
return "READ COMMITTED", nil
|
||||||
|
case sql.LevelReadUncommitted:
|
||||||
|
return "READ UNCOMMITTED", nil
|
||||||
|
case sql.LevelSerializable:
|
||||||
|
return "SERIALIZABLE", nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue