diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml deleted file mode 100644 index bad8149..0000000 --- a/.github/workflows/reviewdog.yml +++ /dev/null @@ -1,27 +0,0 @@ -# This is a basic workflow to help you get started with Actions - -name: CI - -# Controls when the action will run. Triggers the workflow on push or pull request -# events but only for the master branch -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -# A workflow run is made up of one or more jobs that can run sequentially or in parallel -jobs: - # This workflow contains a single job called "build" - build: - # The type of runner that the job will run on - runs-on: ubuntu-latest - - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - # Runs a single command using the runners shell - - name: Run golangci-lint with reviewdog - uses: reviewdog/action-golangci-lint@v1.1.3 diff --git a/.reviewdog.yml b/.reviewdog.yml new file mode 100644 index 0000000..2e243ff --- /dev/null +++ b/.reviewdog.yml @@ -0,0 +1,8 @@ +runner: + golint: + cmd: golint ./... + errorformat: + - "%f:%l:%c: %m" + level: warning + govet: + cmd: go vet -all . diff --git a/.travis.yml b/.travis.yml index 849897f..814f0f1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,7 @@ matrix: - env: DB=MYSQL5.7 sudo: required dist: trusty - go: 1.14.x + go: 1.13.x services: - docker before_install: @@ -14,34 +14,20 @@ matrix: - go get -u github.com/golang/dep/cmd/dep - dep ensure - docker pull mysql:5.7 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_ALLOW_EMPTY_PASSWORD=yes mysql:5.7 --innodb_log_file_size=256MB + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_USER=msandbox -e MYSQL_PASSWORD=msandbox mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB - sleep 30 - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - - mysql < testdata/schema/sakila.sql before_script: - - export TEST_DSN="root:@tcp(127.0.0.1:3307)/sakila?parseTime=true" + - export TEST_DSN="msandbox:msandbox@tcp(127.0.0.1:3307)/sakila?parseTime=true" - - env: DB=MYSQL5.6 - sudo: required - dist: trusty - go: 1.14.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get -u github.com/golang/dep/cmd/dep - - dep ensure - - docker pull mysql:5.6 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_ALLOW_EMPTY_PASSWORD=yes mysql:5.6 - --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB - - sleep 30 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - - mysql < testdata/schema/sakila.sql - before_script: - - export TEST_DSN="root@tcp(127.0.0.1:3307)/sakila?parseTime=true" +install: + - mkdir -p ~/bin/ && export PATH="~/bin/:$PATH" + - curl -sfL https://raw.githubusercontent.com/reviewdog/reviewdog/master/install.sh| sh -s -- -b ~/bin + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b ~/bin/ v1.25.1 script: - - go test -v -race ./... + - ~/bin/golangci-lint run --out-format=line-number | env REVIEWDOG_GITHUB_API_TOKEN=${{ secret.token }} ~/bin/reviewdog -f=golangci-lint -level=error -reporter=github-pr-review -name='Required checks' +# - ~/bin/reviewdog -conf=.reviewdog.yml -reporter=github-pr-check + +after_success: diff --git a/.travis/docker.cnf b/.travis/docker.cnf index 40226a3..ab18907 100644 --- a/.travis/docker.cnf +++ b/.travis/docker.cnf @@ -1,4 +1,5 @@ [client] -user = root +user = msandbox +password=msandbox host = 127.0.0.1 port = 3307 diff --git a/Gopkg.lock b/Gopkg.lock index afccfe0..a640223 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,23 +2,31 @@ [[projects]] - branch = "master" - digest = "1:125f098c5294814b5d9ce2194f03a435c35a6208f029e61681bed9a893d68133" - name = "github.com/alecthomas/template" - packages = [ - ".", - "parse", - ] + digest = "1:3ff78268777dc67ff46530ab776eeb0a819c49edee5348b64c04db116437befd" + name = "github.com/alecthomas/kong" + packages = ["."] pruneopts = "UT" - revision = "fb15b899a75114aa79cc930e33c46b577cc664b1" + revision = "ed7caf6841a0418e43c80c98f9f22ce80b437eed" + version = "v0.2.9" [[projects]] - branch = "master" - digest = "1:f780e2e814981de9dbdc2ca49d6c91d4aa4e2a1639ebc4106fafb0dc54fa5bd8" - name = "github.com/alecthomas/units" + digest = "1:d2e776e9706a8e4e7cb4ada3e6c49058a3287b3bf62a05c22585c1d0f07a4dae" + name = "github.com/apoorvam/goterminal" packages = ["."] pruneopts = "UT" - revision = "f65c72e2690dc4b403c8bd637baf4611cd4c069b" + revision = "614d345c47e510f5bc95fef660f25d0c00d224e4" + version = "v1.0" + +[[projects]] + digest = "1:4fc7ef733af25ff9e4a4799eb7d541b8d74f676896a3fd4446adefc86253afd8" + name = "github.com/brianvoe/gofakeit" + packages = [ + ".", + "data", + ] + pruneopts = "UT" + revision = "19b40d7d6241163eba8b06ddec1041a36db06c54" + version = "v3.20.2" [[projects]] digest = "1:b06d27db2f87588507f245345f329b4192809e72ab32385a26a6eb25988eb82e" @@ -28,39 +36,20 @@ revision = "031be390f409fb4bac8fb299e3bcd101479f89f8" [[projects]] - digest = "1:2d8307a7345570ece92d85e80f17c9de719fa44a11f5ccd828ab2bb4a328314c" - name = "github.com/go-ini/ini" - packages = ["."] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] pruneopts = "UT" - revision = "700781759788472518f4ea52321519b066061a7b" - version = "v1.48.0" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" [[projects]] - digest = "1:ec6f9bf5e274c833c911923c9193867f3f18788c461f76f05f62bb1510e0ae65" + digest = "1:a3af33fd30536d9c5a7b6ad542390919e57c1125097a194c7b3d9904109a0611" name = "github.com/go-sql-driver/mysql" packages = ["."] pruneopts = "UT" - revision = "72cd26f257d44c1114970e19afddcd812016007e" - version = "v1.4.1" - -[[projects]] - digest = "1:e6e8ee8a9aa4efb33290bd5e6615822635884bc007923359e3a10b57eff36bb3" - name = "github.com/gosuri/uilive" - packages = ["."] - pruneopts = "UT" - revision = "4512d98b127f3f3a1b7c3cf1104969fdd17b31d9" - version = "v0.0.3" - -[[projects]] - digest = "1:f6ae4b3c3d4411bffa0b8045fff29974e1eb866e1828786972059f89a130aecf" - name = "github.com/gosuri/uiprogress" - packages = [ - ".", - "util/strutil", - ] - pruneopts = "UT" - revision = "d0567a9d84a1c40dd7568115ea66f4887bf57b33" - version = "0.0.1" + revision = "17ef3dd9d98b69acec3e85878995ada9533a9370" + version = "v1.5.0" [[projects]] digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b" @@ -79,90 +68,94 @@ revision = "4178557ae428460c3780a381c824a1f3aceb6325" [[projects]] - digest = "1:31e761d97c76151dde79e9d28964a812c46efc5baee4085b86f68f0c654450de" + digest = "1:09cb61dc19af93deae01587e2fdb1c081e0bf48f1a5ad5fa24f48750dc57dce8" name = "github.com/konsorten/go-windows-terminal-sequences" packages = ["."] pruneopts = "UT" - revision = "f55edac94c9bbba5d6182a4be46d86a2c9b5b50e" - version = "v1.0.2" + revision = "edb144dfd453055e1e49a3d8b410a660b5a87613" + version = "v1.0.3" [[projects]] - digest = "1:ca955a9cd5b50b0f43d2cc3aeb35c951473eeca41b34eb67507f1dbcc0542394" - name = "github.com/kr/pretty" + digest = "1:0c58d31abe2a2ccb429c559b6292e7df89dcda675456fecc282fa90aa08273eb" + name = "github.com/mattn/go-isatty" packages = ["."] pruneopts = "UT" - revision = "73f6ac0b30a98e433b289500d779f50c1a6f0712" - version = "v0.1.0" + revision = "7b513a986450394f7bbf1476909911b3aa3a55ce" + version = "v0.0.12" [[projects]] - digest = "1:15b5cc79aad436d47019f814fde81a10221c740dc8ddf769221a65097fb6c2e9" - name = "github.com/kr/text" + digest = "1:9e1d37b58d17113ec3cb5608ac0382313c5b59470b94ed97d0976e69c7022314" + name = "github.com/pkg/errors" packages = ["."] pruneopts = "UT" - revision = "e2ffdb16a802fe2bb95e2e35ff34f0e53aeef34f" - version = "v0.1.0" + revision = "614d223910a179a466c1767a985424175c39b465" + version = "v0.9.1" [[projects]] - digest = "1:d62282425ffb75047679d7e2c3b980eea7f82c05ef5fb9142ee617ebac6e7432" - name = "github.com/mattn/go-isatty" - packages = ["."] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] pruneopts = "UT" - revision = "88ba11cfdc67c7588b30042edf244b2875f892b6" - version = "v0.0.10" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" [[projects]] - digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" - name = "github.com/pkg/errors" + digest = "1:05eebdd5727fea23083fce0d98d307d70c86baed644178e81608aaa9f09ea469" + name = "github.com/sirupsen/logrus" packages = ["."] pruneopts = "UT" - revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" - version = "v0.8.1" + revision = "60c74ad9be0d874af0ab0daef6ab07c5c5911f0d" + version = "v1.6.0" [[projects]] - digest = "1:04457f9f6f3ffc5fea48e71d62f2ca256637dee0a04d710288e27e05c8b41976" - name = "github.com/sirupsen/logrus" - packages = ["."] + digest = "1:5e8f46b412421d2d6cceea845d28ac46f3f5a5f60a6e86f0ee75e24fd43a02a9" + name = "github.com/stretchr/testify" + packages = ["assert"] pruneopts = "UT" - revision = "839c75faf7f98a33d445d181f3018b5c3409a45e" - version = "v1.4.2" + revision = "3ebf1ddaeb260c4b1ae502a01c7844fa8c1fa0e9" + version = "v1.5.1" [[projects]] branch = "master" - digest = "1:7c927f17d868be652a4cfe7de23e4292dea5b14d974a1d536e3b7cb7e79fd695" + digest = "1:d85f9416a95aba57c985b1f48f900da97404d8258c3374cce4e49616c57bf8c9" name = "golang.org/x/sys" - packages = ["unix"] + packages = [ + "internal/unsafeheader", + "unix", + ] pruneopts = "UT" - revision = "b09406accb4736d857a32bf9444cd7edae2ffa79" + revision = "7e40ca221e254089b05fb9efcd69844a57f6a367" [[projects]] - digest = "1:c25289f43ac4a68d88b02245742347c94f1e108c534dda442188015ff80669b3" - name = "google.golang.org/appengine" - packages = ["cloudsql"] + digest = "1:2d32701bd465142606b3a84562c6c63b2950f0fb1bf82f231116a7b89de40e2d" + name = "gopkg.in/ini.v1" + packages = ["."] pruneopts = "UT" - revision = "971852bfffca25b069c31162ae8f247a3dba083b" - version = "v1.6.5" + revision = "ad8a10643d24d67f464955e42b23e0d42d60fcb4" + version = "v1.56.0" [[projects]] - digest = "1:c06d9e11d955af78ac3bbb26bd02e01d2f61f689e1a3bce2ef6fb683ef8a7f2d" - name = "gopkg.in/alecthomas/kingpin.v2" + digest = "1:55b110c99c5fdc4f14930747326acce56b52cfce60b24b1c03ef686ac0e46bb1" + name = "gopkg.in/yaml.v2" packages = ["."] pruneopts = "UT" - revision = "947dcec5ba9c011838740e680966fd7087a71d0d" - version = "v2.2.6" + revision = "53403b58ad1b561927d19068c655246f2db79d48" + version = "v2.2.8" [solve-meta] analyzer-name = "dep" analyzer-version = 1 input-imports = [ - "github.com/go-ini/ini", + "github.com/alecthomas/kong", + "github.com/apoorvam/goterminal", + "github.com/brianvoe/gofakeit", "github.com/go-sql-driver/mysql", - "github.com/gosuri/uiprogress", "github.com/hashicorp/go-version", "github.com/icrowley/fake", - "github.com/kr/pretty", "github.com/pkg/errors", "github.com/sirupsen/logrus", - "gopkg.in/alecthomas/kingpin.v2", + "github.com/stretchr/testify/assert", + "gopkg.in/ini.v1", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index f6c7e07..edcc18d 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -26,16 +26,20 @@ [[constraint]] - name = "github.com/go-ini/ini" - version = "1.48.0" + name = "github.com/alecthomas/kong" + version = "0.2.9" [[constraint]] - name = "github.com/go-sql-driver/mysql" - version = "1.4.1" + name = "github.com/apoorvam/goterminal" + version = "1.0.0" + +[[constraint]] + name = "github.com/brianvoe/gofakeit" + version = "3.20.2" [[constraint]] - name = "github.com/gosuri/uiprogress" - version = "0.0.1" + name = "github.com/go-sql-driver/mysql" + version = "1.5.0" [[constraint]] name = "github.com/hashicorp/go-version" @@ -45,21 +49,21 @@ branch = "master" name = "github.com/icrowley/fake" -[[constraint]] - name = "github.com/kr/pretty" - version = "0.1.0" - [[constraint]] name = "github.com/pkg/errors" - version = "0.8.1" + version = "0.9.1" [[constraint]] name = "github.com/sirupsen/logrus" - version = "1.4.2" + version = "1.6.0" + +[[constraint]] + name = "github.com/stretchr/testify" + version = "1.5.1" [[constraint]] - name = "gopkg.in/alecthomas/kingpin.v2" - version = "2.2.6" + name = "gopkg.in/ini.v1" + version = "1.56.0" [prune] go-tests = true diff --git a/cmd/run.go b/cmd/run.go new file mode 100644 index 0000000..d5277f9 --- /dev/null +++ b/cmd/run.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "database/sql" + "fmt" + "net" + "os" + "sync" + + "github.com/Percona-Lab/mysql_random_data_load/internal/insert" + "github.com/Percona-Lab/mysql_random_data_load/internal/ptdsn" + "github.com/Percona-Lab/mysql_random_data_load/tableparser" + "github.com/apoorvam/goterminal" + "github.com/go-sql-driver/mysql" + "github.com/pkg/errors" +) + +type RunCmd struct { + DSN string `name:"dsn" help:"Connection string in Pecona toolkit"` + Database string `name:"database" short:"d" help:"Database schema"` + Table string `name:"table" short:"t" help:"Table name"` + Host string `name:"host" short:"H" help:"Host name/IP"` + Port int `name:"port" short:"P" help:"MySQL port to connect to"` + User string `name:"user" short:"u" help:"MySQL username"` + Password string `name:"password" short:"p" help:"MySQL password"` + ConfigFile string `name:"config-file" help:"MySQL config file"` + + Rows int64 `name:"rows" required:"true" help:"Number of rows to insert"` + BulkSize int64 `name:"bulk-size" help:"Number of rows per insert statement" default:"1000"` + DryRun bool `name:"dry-run" help:"Print queries to the standard output instead of inserting them into the db"` + Quiet bool `name:"quiet" help:"Do not print progress bar"` +} + +// Run starts inserting data. +func (cmd *RunCmd) Run() error { + dsn, err := cmd.mysqlParams() + if err != nil { + return err + } + + db, err := cmd.connect(dsn) + if err != nil { + return err + } + + table, err := tableparser.New(db, dsn.Database, dsn.Table) + if err != nil { + return errors.Wrap(err, "cannot parse table") + } + + _, err = cmd.run(db, table) + return err +} + +func (cmd *RunCmd) run(db *sql.DB, table *tableparser.Table) (int64, error) { + ins := insert.New(db, table) + wg := &sync.WaitGroup{} + + if !cmd.Quiet && !cmd.DryRun { + wg.Add(1) + startProgressBar(cmd.Rows, ins.NotifyChan(), wg) + } + + if cmd.DryRun { + return ins.DryRun(cmd.Rows, cmd.BulkSize) + } + + n, err := ins.Run(cmd.Rows, cmd.BulkSize) + wg.Wait() + return n, err +} + +func startProgressBar(total int64, c chan int64, wg *sync.WaitGroup) { + go func() { + writer := goterminal.New(os.Stdout) + var count int64 + for n := range c { + count += n + writer.Clear() + fmt.Fprintf(writer, "Writing (%d/%d) rows...\n", count, total) + writer.Print() //nolint + } + writer.Reset() + wg.Done() + }() +} + +func (cmd *RunCmd) connect(dsn *ptdsn.PTDSN) (*sql.DB, error) { + netType := "tcp" + address := net.JoinHostPort(dsn.Host, fmt.Sprintf("%d", dsn.Port)) + + if dsn.Host == "localhost" { + netType = "unix" + address = dsn.Host + } + + cfg := &mysql.Config{ + User: dsn.User, + Passwd: dsn.Password, + Net: netType, + Addr: address, + DBName: dsn.Database, + AllowCleartextPasswords: true, + AllowNativePasswords: true, + AllowOldPasswords: true, + CheckConnLiveness: true, + ParseTime: true, + } + + return sql.Open("mysql", cfg.FormatDSN()) +} + +func (cmd *RunCmd) mysqlParams() (*ptdsn.PTDSN, error) { + dsn, err := ptdsn.Parse(cmd.DSN) + if err != nil { + return nil, errors.Wrap(err, "cannot get connection parameters") + } + + if cmd.Database != "" { + dsn.Database = cmd.Database + } + if cmd.Table != "" { + dsn.Table = cmd.Table + } + + if dsn.Database == "" { + return nil, fmt.Errorf("you need to specify a database") + } + if dsn.Table == "" { + return nil, fmt.Errorf("you need to specify a table name") + } + + if cmd.Host != "" { + dsn.Host = cmd.Host + } + if cmd.Port != 0 { + dsn.Port = cmd.Port + } + if cmd.User != "" { + dsn.User = cmd.User + } + if cmd.Password != "" { + dsn.Password = cmd.Password + } + + return dsn, nil +} diff --git a/docker-compose.yml b/docker-compose.yml index a8c9294..a781aeb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,35 +1,46 @@ version: '3' services: + mysql5.5: + image: ${MYSQL_IMAGE:-mysql:5.5} + ports: + - "3305:3306" + environment: + - MYSQL_ROOT_PASSWORD=root + - MYSQL_USER=msandbox + - MYSQ_PASSWORD=msandbox + command: --performance-schema --secure-file-priv="" + volumes: + - ./testdata/schema/:/docker-entrypoint-initdb.d/:rw mysql5.6: image: ${MYSQL_IMAGE:-mysql:5.6} ports: - - ${MYSQL_HOST:-127.0.0.1}:${MYSQL_PORT:-3306}:3306 + - "3306:3306" environment: - - MYSQL_ALLOW_EMPTY_PASSWORD=yes - # MariaDB >= 10.0.12 doesn't enable Performance Schema by default so we need to do it manually - # https://mariadb.com/kb/en/mariadb/performance-schema-overview/#activating-the-performance-schema + - MYSQL_ROOT_PASSWORD=root + - MYSQL_USER=msandbox + - MYSQ_PASSWORD=msandbox command: --performance-schema --secure-file-priv="" volumes: - ./testdata/schema/:/docker-entrypoint-initdb.d/:rw mysql5.7: image: ${MYSQL_IMAGE:-mysql:5.7} ports: - - ${MYSQL_HOST:-127.0.0.1}:${MYSQL_PORT:-3307}:3306 + - "3307:3306" environment: - - MYSQL_ALLOW_EMPTY_PASSWORD=yes - # MariaDB >= 10.0.12 doesn't enable Performance Schema by default so we need to do it manually - # https://mariadb.com/kb/en/mariadb/performance-schema-overview/#activating-the-performance-schema + - MYSQL_ROOT_PASSWORD=root + - MYSQL_USER=msandbox + - MYSQ_PASSWORD=msandbox command: --performance-schema --secure-file-priv="" volumes: - ./testdata/schema/:/docker-entrypoint-initdb.d/:rw mysql8.0: image: ${MYSQL_IMAGE:-mysql:8.0.3} ports: - - ${MYSQL_HOST:-127.0.0.1}:${MYSQL_PORT:-3308}:3306 + - "3308:3306" environment: - - MYSQL_ALLOW_EMPTY_PASSWORD=yes - # MariaDB >= 10.0.12 doesn't enable Performance Schema by default so we need to do it manually - # https://mariadb.com/kb/en/mariadb/performance-schema-overview/#activating-the-performance-schema + - MYSQL_ROOT_PASSWORD=root + - MYSQL_USER=msandbox + - MYSQ_PASSWORD=msandbox command: --performance-schema --secure-file-priv="" volumes: - ./testdata/schema/:/docker-entrypoint-initdb.d/:rw diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8fae1e3 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module github.com/Percona-Lab/mysql_random_data_load + +go 1.16 + +require ( + github.com/alecthomas/kong v0.2.9 + github.com/apoorvam/goterminal v0.0.0-20180523175556-614d345c47e5 + github.com/brianvoe/gofakeit v3.18.0+incompatible + github.com/corpix/uarand v0.0.0-20170723150923-031be390f409 // indirect + github.com/go-sql-driver/mysql v1.5.0 + github.com/hashicorp/go-version v1.2.0 + github.com/icrowley/fake v0.0.0-20180203215853-4178557ae428 + github.com/mattn/go-isatty v0.0.12 // indirect + github.com/pkg/errors v0.9.1 + github.com/sirupsen/logrus v1.6.0 + github.com/smartystreets/goconvey v1.6.4 // indirect + github.com/stretchr/testify v1.5.1 + golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25 // indirect + gopkg.in/ini.v1 v1.56.0 + gopkg.in/yaml.v2 v2.2.8 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9d3cf84 --- /dev/null +++ b/go.sum @@ -0,0 +1,56 @@ +github.com/alecthomas/kong v0.2.9 h1:WGuTS/N2/NQ/9LymVqpr1ifZ4EEkQPvwFHqZs6ak5IU= +github.com/alecthomas/kong v0.2.9/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq+lElKxE= +github.com/apoorvam/goterminal v0.0.0-20180523175556-614d345c47e5 h1:VYqcjykqpcq262cDxBAkAelSdg6HETkxgwzQRTS40Aw= +github.com/apoorvam/goterminal v0.0.0-20180523175556-614d345c47e5/go.mod h1:E7x8aDc3AQzDKjEoIZCt+XYheHk2OkP+p2UgeNjecH8= +github.com/brianvoe/gofakeit v3.18.0+incompatible h1:wDOmHc9DLG4nRjUVVaxA+CEglKOW72Y5+4WNxUIkjM8= +github.com/brianvoe/gofakeit v3.18.0+incompatible/go.mod h1:kfwdRA90vvNhPutZWfH7WPaDzUjz+CZFqG+rPkOjGOc= +github.com/corpix/uarand v0.0.0-20170723150923-031be390f409 h1:9A+mfQmwzZ6KwUXPc8nHxFtKgn9VIvO3gXAOspIcE3s= +github.com/corpix/uarand v0.0.0-20170723150923-031be390f409/go.mod h1:JSm890tOkDN+M1jqN8pUGDKnzJrsVbJwSMHBY4zwz7M= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/icrowley/fake v0.0.0-20180203215853-4178557ae428 h1:Mo9W14pwbO9VfRe+ygqZ8dFbPpoIK1HFrG/zjTuQ+nc= +github.com/icrowley/fake v0.0.0-20180203215853-4178557ae428/go.mod h1:uhpZMVGznybq1itEKXj6RYw9I71qK4kH+OGMjRC4KEo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25 h1:OKbAoGs4fGM5cPLlVQLZGYkFC8OnOfgo6tt0Smf9XhM= +golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.56.0 h1:DPMeDvGTM54DXbPkVIZsp19fp/I2K7zwA/itHYHKo8Y= +gopkg.in/ini.v1 v1.56.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/getters/getters.go b/internal/getters/getters.go index 3d3c9aa..fec1195 100644 --- a/internal/getters/getters.go +++ b/internal/getters/getters.go @@ -1,11 +1,11 @@ package getters // All types defined here satisfy the Getter interface -// type Getter interface { -// Value() interface{} -// Quote() string -// String() string -// } +type Getter interface { + Value() interface{} + Quote() string + String() string +} const ( nilFrequency = 10 diff --git a/internal/getters/string.go b/internal/getters/string.go index 3e5f674..072471b 100644 --- a/internal/getters/string.go +++ b/internal/getters/string.go @@ -3,36 +3,45 @@ package getters import ( "fmt" "math/rand" + "regexp" - "github.com/icrowley/fake" + "github.com/brianvoe/gofakeit" ) // RandomString getter type RandomString struct { name string - maxSize int64 + maxSize uint64 allowNull bool + fn func() string } +var ( + emailRe = regexp.MustCompile(`email`) + firstNameRe = regexp.MustCompile(`first.*name`) + lastNameRe = regexp.MustCompile(`last.*name`) + nameRe = regexp.MustCompile(`name`) + phoneRe = regexp.MustCompile(`phone`) + zipRe = regexp.MustCompile(`zip`) + colorRe = regexp.MustCompile(`color`) + ipAddressRe = regexp.MustCompile(`ip.*(?:address)*`) + addressRe = regexp.MustCompile(`address`) + stateRe = regexp.MustCompile(`state`) + cityRe = regexp.MustCompile(`city`) + countryRe = regexp.MustCompile(`country`) + genderRe = regexp.MustCompile(`gender`) + urlRe = regexp.MustCompile(`url`) + domainre = regexp.MustCompile(`domain`) +) + func (r *RandomString) Value() interface{} { if r.allowNull && rand.Int63n(100) < nilFrequency { return nil } - var s string - maxSize := uint64(r.maxSize) - if maxSize == 0 { - maxSize = uint64(rand.Int63n(100)) - } - if maxSize <= 10 { - s = fake.FirstName() - } else if maxSize < 30 { - s = fake.FullName() - } else { - s = fake.Sentence() - } - if len(s) > int(maxSize) { - s = s[:int(maxSize)] + s := r.fn() + if len(s) > int(r.maxSize) { + s = s[:int(r.maxSize)] } return s } @@ -45,6 +54,7 @@ func (r *RandomString) String() string { return v.(string) } +// Quote returns a quoted string func (r *RandomString) Quote() string { v := r.Value() if v == nil { @@ -54,5 +64,36 @@ func (r *RandomString) Quote() string { } func NewRandomString(name string, maxSize int64, allowNull bool) *RandomString { - return &RandomString{name, maxSize, allowNull} + var fn func() string + + switch { + case emailRe.MatchString(name): + fn = gofakeit.Email + case firstNameRe.MatchString(name): + fn = gofakeit.FirstName + case lastNameRe.MatchString(name): + fn = gofakeit.LastName + case nameRe.MatchString(name): + fn = gofakeit.Name + case phoneRe.MatchString(name): + fn = gofakeit.PhoneFormatted + case zipRe.MatchString(name): + fn = gofakeit.Zip + case colorRe.MatchString(name): + fn = gofakeit.Color + case cityRe.MatchString(name): + fn = gofakeit.City + case countryRe.MatchString(name): + fn = gofakeit.Country + case addressRe.MatchString(name): + fn = gofakeit.Street + case ipAddressRe.MatchString(name): + fn = gofakeit.IPv4Address + default: + fn = func() string { + return gofakeit.Paragraph(10, 10, 10, " ") + } + } + + return &RandomString{name, uint64(maxSize), allowNull, fn} } diff --git a/internal/insert/insert.go b/internal/insert/insert.go new file mode 100644 index 0000000..61596d9 --- /dev/null +++ b/internal/insert/insert.go @@ -0,0 +1,332 @@ +package insert + +import ( + "database/sql" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "time" + + "github.com/Percona-Lab/mysql_random_data_load/internal/getters" + "github.com/Percona-Lab/mysql_random_data_load/tableparser" +) + +type Insert struct { + db *sql.DB + table *tableparser.Table + writer io.Writer + notifyChan chan int64 +} + +var ( + maxValues = map[string]int64{ + "tinyint": 0XF, + "smallint": 0xFF, + "mediumint": 0x7FFFF, + "int": 0x7FFFFFFF, + "integer": 0x7FFFFFFF, + "float": 0x7FFFFFFF, + "decimal": 0x7FFFFFFF, + "double": 0x7FFFFFFF, + "bigint": 0x7FFFFFFFFFFFFFFF, + } +) + +// New returns a new Insert instance. +func New(db *sql.DB, table *tableparser.Table) *Insert { + return &Insert{ + db: db, + table: table, + writer: os.Stdout, + } +} + +// SetWriter lets you specify a custom writer. The default is Stdout. +func (in *Insert) SetWriter(w io.Writer) { + in.writer = w +} + +func (in *Insert) NotifyChan() chan int64 { + if in.notifyChan != nil { + close(in.notifyChan) + } + + in.notifyChan = make(chan int64) + return in.notifyChan +} + +// Run starts the insert process. +func (in *Insert) Run(count, bulksize int64) (int64, error) { + return in.run(count, bulksize, false) +} + +// DryRun starts writing the generated queries to the specified writer. +func (in *Insert) DryRun(count, bulksize int64) (int64, error) { + return in.run(count, bulksize, true) +} + +func (in *Insert) run(count int64, bulksize int64, dryRun bool) (int64, error) { + if in.notifyChan != nil { + defer close(in.notifyChan) + } + + // Example: want 11 rows with bulksize 4: + // count = int(11 / 4) = 2 -> 2 bulk inserts having 4 rows each = 8 rows + // We need to run this insert twice: + // INSERT INTO table (f1, f2) VALUES (?, ?), (?, ?), (?, ?), (?, ?) + // 1 2 3 4 + + // remainder = rows - count = 11 - 8 = 3 + // And then, we need to run this insert once to complete 11 rows + // INSERT INTO table (f1, f2) VALUES (?, ?), (?, ?), (?, ?) + // 1 2 3 + completeInserts := count / bulksize + remainder := count - completeInserts*bulksize + + var n, okCount int64 + var err error + + for i := int64(0); i < completeInserts; i++ { + n, err = in.insert(bulksize, dryRun) + okCount += n + if err != nil { + return okCount, err + } + in.notify(n) + } + + n, err = in.insert(remainder, dryRun) + okCount += n + in.notify(n) + + return okCount, err +} + +func (in *Insert) notify(n int64) { + if in.notifyChan != nil { + select { + case in.notifyChan <- n: + default: + } + } +} + +func (in *Insert) insert(count int64, dryRun bool) (int64, error) { + if count < 1 { + return 0, nil + } + values := make([]string, 0, count) + insertQuery := generateInsertStmt(in.table) + + for i := int64(0); i < count; i++ { + valueFns := makeValueFuncs(in.db, in.table.Fields, nil) + values = append(values, valueFns.String()) + } + + insertQuery += strings.Join(values, ",\n") + + if dryRun { + if _, err := in.writer.Write([]byte(insertQuery + "\n")); err != nil { + return 0, err + } + return count, nil + } + + res, err := in.db.Exec(insertQuery) + if err != nil { + fmt.Println(insertQuery) + return 0, err + } + ra, _ := res.RowsAffected() + return ra, err +} + +func generateInsertStmt(table *tableparser.Table) string { + fields := getFieldNames(table.Fields) + query := fmt.Sprintf("INSERT IGNORE INTO %s.%s (%s) VALUES \n", //nolint + backticks(table.Schema), + backticks(table.Name), + strings.Join(fields, ","), + ) + return query +} + +func getFieldNames(fields []tableparser.Field) []string { + fieldNames := make([]string, 0, len(fields)) + + for _, field := range fields { + if !isSupportedType(field.DataType) { + continue + } + if !field.IsNullable && field.ColumnKey == "PRI" && + strings.Contains(field.Extra, "auto_increment") { + continue + } + fieldNames = append(fieldNames, backticks(field.ColumnName)) + } + return fieldNames +} + +func isSupportedType(fieldType string) bool { + supportedTypes := map[string]bool{ + "tinyint": true, + "smallint": true, + "mediumint": true, + "int": true, + "integer": true, + "bigint": true, + "float": true, + "decimal": true, + "double": true, + "char": true, + "varchar": true, + "date": true, + "datetime": true, + "timestamp": true, + "time": true, + "year": true, + "tinyblob": true, + "tinytext": true, + "blob": true, + "text": true, + "mediumblob": true, + "mediumtext": true, + "longblob": true, + "longtext": true, + "binary": true, + "varbinary": true, + "enum": true, + "set": true, + } + _, ok := supportedTypes[fieldType] + return ok +} + +func makeValueFuncs(conn *sql.DB, fields []tableparser.Field, cg map[string]string) insertValues { + var values []getters.Getter + for _, field := range fields { + if !field.IsNullable && field.ColumnKey == "PRI" && strings.Contains(field.Extra, "auto_increment") { + continue + } + if field.Constraint != nil { + samples, err := getSamples(conn, field.Constraint.ReferencedTableSchema, + field.Constraint.ReferencedTableName, + field.Constraint.ReferencedColumnName, + 100, field.DataType) + if err != nil { + log.Printf("cannot get samples for field %q: %s\n", field.ColumnName, err) + continue + } + values = append(values, getters.NewRandomSample(field.ColumnName, samples, field.IsNullable)) + continue + } + maxValue := maxValues["bigint"] + if m, ok := maxValues[field.DataType]; ok { + maxValue = m + } + switch field.DataType { + case "tinyint": + values = append(values, getters.NewRandomIntRange(field.ColumnName, 0, 1, field.IsNullable)) + case "smallint", "mediumint", "int", "integer", "bigint": + values = append(values, getters.NewRandomInt(field.ColumnName, maxValue, field.IsNullable)) + case "float", "decimal", "double": + values = append(values, getters.NewRandomDecimal(field.ColumnName, + field.NumericPrecision.Int64-field.NumericScale.Int64, field.IsNullable)) + case "char", "varchar": + values = append(values, getters.NewRandomString(field.ColumnName, + field.CharacterMaximumLength.Int64, field.IsNullable)) + case "date": + values = append(values, getters.NewRandomDate(field.ColumnName, field.IsNullable)) + case "datetime", "timestamp": + values = append(values, getters.NewRandomDateTime(field.ColumnName, field.IsNullable)) + case "tinyblob", "tinytext", "blob", "text", "mediumtext", "mediumblob", "longblob", "longtext": + values = append(values, getters.NewRandomString(field.ColumnName, + field.CharacterMaximumLength.Int64, field.IsNullable)) + case "time": + values = append(values, getters.NewRandomTime(field.IsNullable)) + case "year": + values = append(values, getters.NewRandomIntRange(field.ColumnName, int64(time.Now().Year()-1), + int64(time.Now().Year()), field.IsNullable)) + case "enum", "set": + values = append(values, getters.NewRandomEnum(field.SetEnumVals, field.IsNullable)) + case "binary", "varbinary": + values = append(values, getters.NewRandomBinary(field.ColumnName, field.CharacterMaximumLength.Int64, field.IsNullable)) + default: + log.Printf("cannot get field type: %s: %s\n", field.ColumnName, field.DataType) + } + } + + return values +} + +func getSamples(conn *sql.DB, schema, table, field string, samples int64, dataType string) ([]interface{}, error) { + var count int64 + var query string + + queryCount := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", schema, table) + if err := conn.QueryRow(queryCount).Scan(&count); err != nil { + return nil, fmt.Errorf("cannot get count for table %q: %s", table, err) + } + + if count < samples { + query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s`", field, schema, table) + } else { + query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE RAND() <= .3 LIMIT %d", + field, schema, table, samples) + } + + rows, err := conn.Query(query) + if err != nil { + return nil, fmt.Errorf("cannot get samples: %s, %s", query, err) + } + defer rows.Close() + + values := []interface{}{} + + for rows.Next() { + var err error + var val interface{} + + switch dataType { + case "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "year": + var v int64 + err = rows.Scan(&v) + val = v + case "char", "varchar", "blob", "text", "mediumtext", + "mediumblob", "longblob", "longtext": + var v string + err = rows.Scan(&v) + val = v + case "binary", "varbinary": + var v []rune + err = rows.Scan(&v) + val = v + case "float", "decimal", "double": + var v float64 + err = rows.Scan(&v) + val = v + case "date", "time", "datetime", "timestamp": + var v time.Time + err = rows.Scan(&v) + val = v + } + if err != nil { + return nil, fmt.Errorf("cannot scan sample: %s", err) + } + values = append(values, val) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("cannot get samples: %s", err) + } + return values, nil +} + +func backticks(val string) string { + if strings.HasPrefix(val, "`") && strings.HasSuffix(val, "`") { + return url.QueryEscape(val) + } + return "`" + url.QueryEscape(val) + "`" +} diff --git a/internal/insert/insert_data.go b/internal/insert/insert_data.go new file mode 100644 index 0000000..d04403a --- /dev/null +++ b/internal/insert/insert_data.go @@ -0,0 +1,18 @@ +package insert + +import "github.com/Percona-Lab/mysql_random_data_load/internal/getters" + +type insertValues []getters.Getter + +func (iv insertValues) String() string { + sep := "" + query := "(" + + for _, v := range iv { + query += sep + v.Quote() + sep = ", " + } + query += ")" + + return query +} diff --git a/internal/insert/insert_test.go b/internal/insert/insert_test.go new file mode 100644 index 0000000..e1ce015 --- /dev/null +++ b/internal/insert/insert_test.go @@ -0,0 +1,27 @@ +package insert + +import ( + "testing" + + "github.com/Percona-Lab/mysql_random_data_load/internal/tu" + "github.com/Percona-Lab/mysql_random_data_load/tableparser" + "github.com/stretchr/testify/assert" +) + +func TestBasic(t *testing.T) { + db := tu.GetMySQLConnection(t) + tu.LoadQueriesFromFile(t, "child.sql") + + table, err := tableparser.New(db, "test", "parent") + assert.NoError(t, err) + + i := New(db, table) + + n, err := i.DryRun(9, 5) + assert.NoError(t, err) + assert.Equal(t, int64(9), n) + + n, err = i.Run(9, 5) + assert.NoError(t, err) + assert.Equal(t, int64(9), n) +} diff --git a/internal/insert/testdata/child.sql b/internal/insert/testdata/child.sql new file mode 100644 index 0000000..b6802b3 --- /dev/null +++ b/internal/insert/testdata/child.sql @@ -0,0 +1,38 @@ +DROP DATABASE IF EXISTS test; +CREATE DATABASE test; + +CREATE TABLE `test`.`child` ( + `child_id` int NOT NULL AUTO_INCREMENT, + `user_name` varchar(45) NOT NULL, + `password` varchar(255) NOT NULL, + `parent_id` int NOT NULL, + `avatar` varchar(255) NOT NULL, + `total_balance` decimal(10,2) NOT NULL DEFAULT '0.00', + `available_cups` int NOT NULL DEFAULT '0', + `sold_cups` int NOT NULL DEFAULT '0', + `total_sales` decimal(10,2) NOT NULL DEFAULT '0.00', + `last_seen` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`child_id`), + UNIQUE KEY `username_idx` (`user_name`), + KEY `parent_id_isx` (`parent_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `test`.`parent` ( + `parent_id` int NOT NULL AUTO_INCREMENT, + `first_name` varchar(45) NOT NULL, + `last_name` varchar(45) NOT NULL, + `email` varchar(45) NOT NULL, + `opt_in` tinyint(1) NOT NULL, + `avatar` int NOT NULL, + `password` varchar(255) NOT NULL, + `registered_with_google` tinyint(1) DEFAULT NULL, + `registered_with_facebook` tinyint(1) DEFAULT NULL, + `recovery_token` varchar(255) DEFAULT NULL, + `verified` tinyint(1) DEFAULT NULL, + `accepted_terms` datetime DEFAULT NULL, + `zip` varchar(10) NOT NULL, + PRIMARY KEY (`parent_id`), + UNIQUE KEY `parent_email_idx` (`email`), + KEY `recovery_token_idx` (`recovery_token`), + KEY `zip` (`zip`) +) ENGINE=InnoDB AUTO_INCREMENT=41 DEFAULT CHARSET=utf8; diff --git a/internal/ptdsn/ptdsn.go b/internal/ptdsn/ptdsn.go new file mode 100644 index 0000000..bf75d3e --- /dev/null +++ b/internal/ptdsn/ptdsn.go @@ -0,0 +1,141 @@ +package ptdsn + +import ( + "fmt" + "os/user" + "strconv" + "strings" + + "github.com/pkg/errors" + "gopkg.in/ini.v1" +) + +type PTDSN struct { + Database string + Host string + Password string + Port int + Table string + User string + Protocol string + ConfigFile string +} + +const ( + defaultMySQLConfigSection = "client" + defaultConfigFile = "~/.my.cnf" +) + +func (d *PTDSN) String() string { + return fmt.Sprintf("%v:%v@%v(%v:%v)/%v", d.User, d.Password, d.Protocol, d.Host, d.Port, d.Database) +} + +// Parse parses the connection string and returns MySQL connection parameters struct. +func Parse(value string) (*PTDSN, error) { + d := &PTDSN{} + parts := strings.Split(value, ",") + + // First, try to parse the values from the config. Those values will be overridden by the other dsn params + for _, part := range parts { + m := strings.Split(part, "=") + key := m[0] + value := "" + if len(m) > 1 { + value = m[1] + } + if key == "F" { + if err := loadMySQLConfigFile(value, d); err != nil { + return nil, errors.Wrap(err, "cannot parse config file") + } + d.ConfigFile = value + } + } + + // If there was no F parameter in the dsn, try to load the default ~/.my.cnf + if d.ConfigFile == "" { + d.ConfigFile = defaultConfigFile + // Don't check for error because the config might not exist + loadMySQLConfigFile(d.ConfigFile, d) // nolint + } + + for _, part := range parts { + m := strings.Split(part, "=") + key := m[0] + value := "" + if len(m) > 1 { + value = m[1] + } + switch key { + case "D": + d.Database = value + case "h": + d.Host = value + if d.Host == "localhost" { + d.Protocol = "unix" + } else { + d.Protocol = "tcp" + } + case "p": + d.Password = value + case "P": + port, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, errors.Wrap(err, "invalid port") + } + d.Port = int(port) + case "t": + d.Table = value + case "u": + d.User = value + } + } + + if d.Protocol == "tcp" && d.Port == 0 { + d.Port = 3306 + } + + return d, nil +} + +func loadMySQLConfigFile(filename string, d *PTDSN) error { + cfg, err := ini.Load(expandHomeDir(filename)) + if err != nil { + return err + } + + section := cfg.Section(defaultMySQLConfigSection) + + if section.HasKey("host") { + d.Host = section.Key("host").String() + } + + if section.HasKey("port") { + portstr := section.Key("port").String() + port, err := strconv.Atoi(portstr) + if err != nil { + return errors.Wrap(err, "invalid port") + } + d.Port = port + } + + if section.HasKey("user") { + d.User = section.Key("user").String() + } + + if section.HasKey("password") { + d.Password = section.Key("password").String() + } + + return nil +} + +func expandHomeDir(dir string) string { + if !strings.HasPrefix(dir, "~") { + return dir + } + u, err := user.Current() + if err != nil { + return dir + } + return u.HomeDir + strings.TrimPrefix(dir, "~") +} diff --git a/testutils/testutils.go b/internal/tu/tu.go similarity index 98% rename from testutils/testutils.go rename to internal/tu/tu.go index 554a256..db1a87c 100644 --- a/testutils/testutils.go +++ b/internal/tu/tu.go @@ -1,4 +1,4 @@ -package testutils +package tu import ( "bufio" @@ -45,8 +45,7 @@ func GetMySQLConnection(tb testing.TB) *sql.DB { dsn := os.Getenv("TEST_DSN") if dsn == "" { - fmt.Printf("%s TEST_DSN environment variable is empty", caller()) - tb.FailNow() + dsn = "msandbox:msandbox@tcp(127.0.0.1:3306)/sakila" } // Parse the DSN in the env var and ensure it has parseTime & multiStatements enabled diff --git a/main.go b/main.go index 07362ae..1cff39b 100644 --- a/main.go +++ b/main.go @@ -1,73 +1,15 @@ package main import ( - "database/sql" - "fmt" - "net/url" - "os" - "os/user" - "runtime" - "strings" - "sync" - "time" + "log" - "github.com/Percona-Lab/mysql_random_data_load/internal/getters" - "github.com/Percona-Lab/mysql_random_data_load/tableparser" - "github.com/go-ini/ini" - "github.com/go-sql-driver/mysql" - "github.com/gosuri/uiprogress" - "github.com/kr/pretty" - - log "github.com/sirupsen/logrus" - kingpin "gopkg.in/alecthomas/kingpin.v2" + "github.com/Percona-Lab/mysql_random_data_load/cmd" + "github.com/alecthomas/kong" ) -type cliOptions struct { - app *kingpin.Application - - // Arguments - Schema *string - TableName *string - Rows *int - // Flags - BulkSize *int - ConfigFile *string - Debug *bool - Factor *float64 - Host *string - MaxRetries *int - MaxThreads *int - NoProgress *bool - Pass *string - Port *int - Print *bool - Samples *int64 - User *string - Version *bool -} - -type mysqlOptions struct { - Host string - Password string - Port int - Sock string - User string -} - var ( - opts *cliOptions - - validFunctions = []string{"int", "string", "date", "date_in_range"} - maxValues = map[string]int64{ - "tinyint": 0XF, - "smallint": 0xFF, - "mediumint": 0x7FFFF, - "int": 0x7FFFFFFF, - "integer": 0x7FFFFFFF, - "float": 0x7FFFFFFF, - "decimal": 0x7FFFFFFF, - "double": 0x7FFFFFFF, - "bigint": 0x7FFFFFFFFFFFFFFF, + cli struct { + Run cmd.RunCmd `cmd:"run" help:"Starts the insert process"` } Version = "0.0.0." @@ -77,572 +19,27 @@ var ( GoVersion = "1.9.2" ) -type getter interface { - Value() interface{} - Quote() string - String() string -} -type insertValues []getter -type insertFunction func(*sql.DB, string, chan int, chan bool, *sync.WaitGroup) - const ( defaultMySQLConfigSection = "client" - defaultConfigFile = "~/.my.cnf" - defaultBulkSize = 1000 ) func main() { - - opts, err := processCliParams() - if err != nil { - log.Fatal(err.Error()) - } - - if *opts.Version { - fmt.Printf("Version : %s\n", Version) - fmt.Printf("Commit : %s\n", Commit) - fmt.Printf("Branch : %s\n", Branch) - fmt.Printf("Build : %s\n", Build) - fmt.Printf("Go version: %s\n", GoVersion) - return - } - - address := *opts.Host - net := "unix" - if address != "localhost" { - net = "tcp" - } - if *opts.Port != 0 { - address = fmt.Sprintf("%s:%d", address, *opts.Port) - } - - dsn := mysql.Config{ - User: *opts.User, - Passwd: *opts.Pass, - Addr: address, - Net: net, - DBName: "", - ParseTime: true, - AllowNativePasswords: true, - } - - db, err := sql.Open("mysql", dsn.FormatDSN()) - if err != nil { - panic(err) - } - db.SetMaxOpenConns(100) - - // SET TimeZone to UTC to avoid errors due to random dates & daylight saving valid values - if _, err = db.Exec(`SET @@session.time_zone = "+00:00"`); err != nil { - log.Printf("Cannot set time zone to UTC: %s\n", err) - db.Close() - os.Exit(1) - } - - table, err := tableparser.NewTable(db, *opts.Schema, *opts.TableName) - if err != nil { - log.Printf("cannot get table %s struct: %s", *opts.TableName, err) - db.Close() - os.Exit(1) - } - - log.SetFormatter(&log.TextFormatter{FullTimestamp: true}) - if *opts.Debug { - log.SetLevel(log.DebugLevel) - *opts.NoProgress = true - } - log.Debug(pretty.Sprint(table)) - - if len(table.Triggers) > 0 { - log.Warnf("There are triggers on the %s table that might affect this process:", *opts.TableName) - for _, t := range table.Triggers { - log.Warnf("Trigger %q, %s %s", t.Trigger, t.Timing, t.Event) - log.Warnf("Statement: %s", t.Statement) - } - } - - if *opts.Rows < 1 { - db.Close() // golint:noerror - log.Warnf("Number of rows < 1. There is nothing to do. Exiting") - os.Exit(1) - } - - if *opts.BulkSize > *opts.Rows { - *opts.BulkSize = *opts.Rows - } - if *opts.BulkSize < 1 { - *opts.BulkSize = defaultBulkSize - } - - if opts.MaxThreads == nil { - *opts.MaxThreads = runtime.NumCPU() * 10 - } - - if *opts.MaxThreads < 1 { - *opts.MaxThreads = 1 - } - - if !*opts.Print { - log.Info("Starting") - } - - // Example: want 11 rows with bulksize 4: - // count = int(11 / 4) = 2 -> 2 bulk inserts having 4 rows each = 8 rows - // We need to run this insert twice: - // INSERT INTO table (f1, f2) VALUES (?, ?), (?, ?), (?, ?), (?, ?) - // remainder = rows - count = 11 - 8 = 3 - // And then, we need to run this insert once to complete 11 rows - // INSERT INTO table (f1, f2) VALUES (?, ?), (?, ?), (?, ?) - newLineOnEachRow := false - count := *opts.Rows / *opts.BulkSize - remainder := *opts.Rows - count**opts.BulkSize - semaphores := makeSemaphores(*opts.MaxThreads) - rowValues := makeValueFuncs(db, table.Fields) - log.Debugf("Must run %d bulk inserts having %d rows each", count, *opts.BulkSize) - - runInsertFunc := runInsert - if *opts.Print { - *opts.MaxThreads = 1 - *opts.NoProgress = true - newLineOnEachRow = true - runInsertFunc = func(db *sql.DB, insertQuery string, resultsChan chan int, sem chan bool, wg *sync.WaitGroup) { - fmt.Println(insertQuery) - resultsChan <- *opts.BulkSize - sem <- true - wg.Done() - } - } - - bar := uiprogress.AddBar(*opts.Rows).AppendCompleted().PrependElapsed() - if !*opts.NoProgress { - uiprogress.Start() - } - - okCount, err := run(db, table, bar, semaphores, rowValues, count, *opts.BulkSize, runInsertFunc, newLineOnEachRow) - if err != nil { - log.Errorln(err) - } - var okrCount, okiCount int // remainder & individual inserts OK count - if remainder > 0 { - log.Debugf("Must run 1 extra bulk insert having %d rows, to complete %d rows", remainder, *opts.Rows) - okrCount, err = run(db, table, bar, semaphores, rowValues, 1, remainder, runInsertFunc, newLineOnEachRow) - if err != nil { - log.Errorln(err) - } - } - - // If there were errors and at this point we have less rows than *rows, - // retry adding individual rows (no bulk inserts) - totalOkCount := okCount + okrCount - retries := 0 - if totalOkCount < *opts.Rows { - log.Debugf("Running extra %d individual inserts (duplicated keys?)", *opts.Rows-totalOkCount) - } - for totalOkCount < *opts.Rows && retries < *opts.MaxRetries { - okiCount, err = run(db, table, bar, semaphores, rowValues, *opts.Rows-totalOkCount, 1, runInsertFunc, newLineOnEachRow) - if err != nil { - log.Errorf("Cannot run extra insert: %s", err) - } - - retries++ - totalOkCount += okiCount - } - - time.Sleep(500 * time.Millisecond) // Let the progress bar to update - if !*opts.Print { - log.Printf("%d rows inserted", totalOkCount) - } - db.Close() -} - -func run(db *sql.DB, table *tableparser.Table, bar *uiprogress.Bar, sem chan bool, - rowValues insertValues, count, bulkSize int, insertFunc insertFunction, newLineOnEachRow bool) (int, error) { - if count == 0 { - return 0, nil - } - var wg sync.WaitGroup - insertQuery := generateInsertStmt(table) - rowsChan := make(chan []getter, 1000) - okRowsChan := countRowsOK(count, bar) - - go generateInsertData(count*bulkSize, rowValues, rowsChan) - defaultSeparator1 := "" - if newLineOnEachRow { - defaultSeparator1 = "\n" - } - - i := 0 - rowsCount := 0 - sep1, sep2 := defaultSeparator1, "" - - for i < count { - rowData := <-rowsChan - rowsCount++ - insertQuery += sep1 + " (" - for _, field := range rowData { - insertQuery += sep2 + field.Quote() - sep2 = ", " - } - insertQuery += ")" - sep1 = ", " - if newLineOnEachRow { - sep1 += "\n" - } - sep2 = "" - if rowsCount < bulkSize { - continue - } - - insertQuery += ";\n" - <-sem - wg.Add(1) - go insertFunc(db, insertQuery, okRowsChan, sem, &wg) - - insertQuery = generateInsertStmt(table) - sep1, sep2 = defaultSeparator1, "" - rowsCount = 0 - i++ - } - - wg.Wait() - okCount := <-okRowsChan - return okCount, nil -} - -func makeSemaphores(count int) chan bool { - sem := make(chan bool, count) - for i := 0; i < count; i++ { - sem <- true - } - return sem -} - -// This go-routine keeps track of how many rows were actually inserted -// by the bulk inserts since one or more rows could generate duplicated -// keys so, not allways the number of inserted rows = number of rows in -// the bulk insert - -func countRowsOK(count int, bar *uiprogress.Bar) chan int { - var totalOk int - resultsChan := make(chan int, 10000) - go func() { - for i := 0; i < count; i++ { - okCount := <-resultsChan - for j := 0; j < okCount; j++ { - bar.Incr() - } - totalOk += okCount - } - resultsChan <- totalOk - }() - return resultsChan -} - -// generateInsertData will generate 'rows' items, where each item in the channel has 'bulkSize' rows. -// For example: -// We need to load 6 rows using a bulk insert having 2 rows per insert, like this: -// INSERT INTO table (f1, f2, f3) VALUES (?, ?, ?), (?, ?, ?) -// -// This function will put into rowsChan 3 elements, each one having the values for 2 rows: -// rowsChan <- [ v1-1, v1-2, v1-3, v2-1, v2-2, v2-3 ] -// rowsChan <- [ v3-1, v3-2, v3-3, v4-1, v4-2, v4-3 ] -// rowsChan <- [ v1-5, v5-2, v5-3, v6-1, v6-2, v6-3 ] -// -func generateInsertData(count int, values insertValues, rowsChan chan []getter) { - for i := 0; i < count; i++ { - insertRow := make([]getter, 0, len(values)) - for _, val := range values { - insertRow = append(insertRow, val) - } - rowsChan <- insertRow - } -} - -func generateInsertStmt(table *tableparser.Table) string { - fields := getFieldNames(table.Fields) - query := fmt.Sprintf("INSERT IGNORE INTO %s.%s (%s) VALUES ", - backticks(table.Schema), - backticks(table.Name), - strings.Join(fields, ","), + ctx := kong.Parse(&cli, + kong.Name("MySQL random data loader"), + kong.Description("Load random data into a MySQL table"), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{ + Compact: false, + Summary: true, + Tree: true, + }), ) - return query -} - -func runInsert(db *sql.DB, insertQuery string, resultsChan chan int, sem chan bool, wg *sync.WaitGroup) { - result, err := db.Exec(insertQuery) - if err != nil { - log.Debugf("Cannot run insert: %s", err) - resultsChan <- 0 - sem <- true - wg.Done() - return - } - - rowsAffected, err := result.RowsAffected() - if err != nil { - log.Errorf("Cannot get rows affected after insert: %s", err) - } - resultsChan <- int(rowsAffected) - sem <- true - wg.Done() -} - -// makeValueFuncs returns an array of functions to generate all the values needed for a single row -func makeValueFuncs(conn *sql.DB, fields []tableparser.Field) insertValues { - var values []getter - for _, field := range fields { - if !field.IsNullable && field.ColumnKey == "PRI" && strings.Contains(field.Extra, "auto_increment") { - continue - } - if field.Constraint != nil { - samples, err := getSamples(conn, field.Constraint.ReferencedTableSchema, - field.Constraint.ReferencedTableName, - field.Constraint.ReferencedColumnName, - 100, field.DataType) - if err != nil { - log.Printf("cannot get samples for field %q: %s\n", field.ColumnName, err) - continue - } - values = append(values, getters.NewRandomSample(field.ColumnName, samples, field.IsNullable)) - continue - } - maxValue := maxValues["bigint"] - if m, ok := maxValues[field.DataType]; ok { - maxValue = m - } - switch field.DataType { - case "tinyint", "smallint", "mediumint", "int", "integer", "bigint": - values = append(values, getters.NewRandomInt(field.ColumnName, maxValue, field.IsNullable)) - case "float", "decimal", "double": - values = append(values, getters.NewRandomDecimal(field.ColumnName, - field.NumericPrecision.Int64-field.NumericScale.Int64, field.IsNullable)) - case "char", "varchar": - values = append(values, getters.NewRandomString(field.ColumnName, - field.CharacterMaximumLength.Int64, field.IsNullable)) - case "date": - values = append(values, getters.NewRandomDate(field.ColumnName, field.IsNullable)) - case "datetime", "timestamp": - values = append(values, getters.NewRandomDateTime(field.ColumnName, field.IsNullable)) - case "tinyblob", "tinytext", "blob", "text", "mediumtext", "mediumblob", "longblob", "longtext": - values = append(values, getters.NewRandomString(field.ColumnName, - field.CharacterMaximumLength.Int64, field.IsNullable)) - case "time": - values = append(values, getters.NewRandomTime(field.IsNullable)) - case "year": - values = append(values, getters.NewRandomIntRange(field.ColumnName, int64(time.Now().Year()-1), - int64(time.Now().Year()), field.IsNullable)) - case "enum", "set": - values = append(values, getters.NewRandomEnum(field.SetEnumVals, field.IsNullable)) - case "binary", "varbinary": - values = append(values, getters.NewRandomBinary(field.ColumnName, field.CharacterMaximumLength.Int64, field.IsNullable)) - default: - log.Printf("cannot get field type: %s: %s\n", field.ColumnName, field.DataType) - } - } - - return values -} - -func getFieldNames(fields []tableparser.Field) []string { - var fieldNames []string - for _, field := range fields { - if !isSupportedType(field.DataType) { - continue - } - if !field.IsNullable && field.ColumnKey == "PRI" && - strings.Contains(field.Extra, "auto_increment") { - continue - } - fieldNames = append(fieldNames, backticks(field.ColumnName)) - } - return fieldNames -} - -func getSamples(conn *sql.DB, schema, table, field string, samples int64, dataType string) ([]interface{}, error) { - var count int64 - var query string - - queryCount := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", schema, table) - if err := conn.QueryRow(queryCount).Scan(&count); err != nil { - return nil, fmt.Errorf("cannot get count for table %q: %s", table, err) - } - - if count < samples { - query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s`", field, schema, table) - } else { - query = fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE RAND() <= .3 LIMIT %d", - field, schema, table, samples) - } - - rows, err := conn.Query(query) - if err != nil { - return nil, fmt.Errorf("cannot get samples: %s, %s", query, err) - } - defer rows.Close() - - values := []interface{}{} - - for rows.Next() { - var err error - var val interface{} - - switch dataType { - case "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "year": - var v int64 - err = rows.Scan(&v) - val = v - case "char", "varchar", "blob", "text", "mediumtext", - "mediumblob", "longblob", "longtext": - var v string - err = rows.Scan(&v) - val = v - case "binary", "varbinary": - var v []rune - err = rows.Scan(&v) - val = v - case "float", "decimal", "double": - var v float64 - err = rows.Scan(&v) - val = v - case "date", "time", "datetime", "timestamp": - var v time.Time - err = rows.Scan(&v) - val = v - } - if err != nil { - return nil, fmt.Errorf("cannot scan sample: %s", err) + switch ctx.Command() { + case "run": + if err := ctx.Run(); err != nil { + log.Fatalf(err.Error()) } - values = append(values, val) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("cannot get samples: %s", err) - } - return values, nil -} - -func backticks(val string) string { - if strings.HasPrefix(val, "`") && strings.HasSuffix(val, "`") { - return url.QueryEscape(val) - } - return "`" + url.QueryEscape(val) + "`" -} - -func isSupportedType(fieldType string) bool { - supportedTypes := map[string]bool{ - "tinyint": true, - "smallint": true, - "mediumint": true, - "int": true, - "integer": true, - "bigint": true, - "float": true, - "decimal": true, - "double": true, - "char": true, - "varchar": true, - "date": true, - "datetime": true, - "timestamp": true, - "time": true, - "year": true, - "tinyblob": true, - "tinytext": true, - "blob": true, - "text": true, - "mediumblob": true, - "mediumtext": true, - "longblob": true, - "longtext": true, - "binary": true, - "varbinary": true, - "enum": true, - "set": true, - } - _, ok := supportedTypes[fieldType] - return ok -} - -func processCliParams() (*cliOptions, error) { - app := kingpin.New("mysql_random_data_loader", "MySQL Random Data Loader") - - opts := &cliOptions{ - app: app, - BulkSize: app.Flag("bulk-size", "Number of rows per insert statement").Default(fmt.Sprintf("%d", defaultBulkSize)).Int(), - ConfigFile: app.Flag("config-file", "MySQL config file").Default(expandHomeDir(defaultConfigFile)).String(), - Debug: app.Flag("debug", "Log debugging information").Bool(), - Factor: app.Flag("fk-samples-factor", "Percentage used to get random samples for foreign keys fields").Default("0.3").Float64(), - Host: app.Flag("host", "Host name/IP").Short('h').String(), - MaxRetries: app.Flag("max-retries", "Number of rows to insert").Default("100").Int(), - MaxThreads: app.Flag("max-threads", "Maximum number of threads to run inserts").Default("1").Int(), - NoProgress: app.Flag("no-progress", "Show progress bar").Default("false").Bool(), - Pass: app.Flag("password", "Password").Short('p').String(), - Port: app.Flag("port", "Port").Short('P').Int(), - Print: app.Flag("print", "Print queries to the standard output instead of inserting them into the db").Bool(), - Samples: app.Flag("max-fk-samples", "Maximum number of samples for foreign keys fields").Default("100").Int64(), - User: app.Flag("user", "User").Short('u').String(), - Version: app.Flag("version", "Show version and exit").Bool(), - - Schema: app.Arg("database", "Database").Required().String(), - TableName: app.Arg("table", "Table").Required().String(), - Rows: app.Arg("rows", "Number of rows to insert").Required().Int(), - } - _, err := app.Parse(os.Args[1:]) - - if err != nil { - return nil, err - } - - if mysqlOpts, err := readMySQLConfigFile(*opts.ConfigFile); err == nil { - checkMySQLParams(opts, mysqlOpts) - } - - return opts, nil -} - -func checkMySQLParams(opts *cliOptions, mysqlOpts *mysqlOptions) { - if *opts.Host == "" && mysqlOpts.Host != "" { - *opts.Host = mysqlOpts.Host - } - - if *opts.Port == 0 && mysqlOpts.Port != 0 { - *opts.Port = mysqlOpts.Port - } - - if *opts.User == "" && mysqlOpts.User != "" { - *opts.User = mysqlOpts.User - } - - if *opts.Pass == "" && mysqlOpts.Password != "" { - *opts.Pass = mysqlOpts.Password - } -} - -func readMySQLConfigFile(filename string) (*mysqlOptions, error) { - cfg, err := ini.Load(expandHomeDir(filename)) - if err != nil { - return nil, err - } - - section := cfg.Section(defaultMySQLConfigSection) - port, _ := section.Key("port").Int() - - mysqlOpts := &mysqlOptions{ - Host: section.Key("host").String(), - Port: port, - User: section.Key("user").String(), - Password: section.Key("password").String(), - } - - return mysqlOpts, nil -} - -func expandHomeDir(dir string) string { - if !strings.HasPrefix(dir, "~") { - return dir - } - u, err := user.Current() - if err != nil { - return dir + default: + log.Fatalf("Unknown command") } - return u.HomeDir + strings.TrimPrefix(dir, "~") } diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 7f6c601..0000000 --- a/main_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package main - -import ( - "fmt" - "reflect" - "sync" - "testing" - "time" - - "github.com/Percona-Lab/mysql_random_data_load/internal/getters" - "github.com/Percona-Lab/mysql_random_data_load/tableparser" - tu "github.com/Percona-Lab/mysql_random_data_load/testutils" -) - -func TestGetSamples(t *testing.T) { - conn := tu.GetMySQLConnection(t) - var wantRows int64 = 100 - samples, err := getSamples(conn, "sakila", "inventory", "inventory_id", wantRows, "int") - tu.Ok(t, err, "error getting samples") - _, ok := samples[0].(int64) - tu.Assert(t, ok, "Wrong data type.") - tu.Assert(t, int64(len(samples)) == wantRows, - "Wrong number of samples. Have %d, want 100.", len(samples)) -} - -func TestGenerateInsertData(t *testing.T) { - wantRows := 3 - - values := []getter{ - getters.NewRandomInt("f1", 100, false), - getters.NewRandomString("f2", 10, false), - getters.NewRandomDate("f3", false), - } - - rowsChan := make(chan []getter, 100) - count := 0 - wg := &sync.WaitGroup{} - wg.Add(1) - - go func() { - for { - select { - case <-time.After(10 * time.Millisecond): - wg.Done() - return - case row := <-rowsChan: - if reflect.TypeOf(row[0]).String() != "*getters.RandomInt" { - fmt.Printf("Expected '*getters.RandomInt' for field [0], got %q\n", reflect.TypeOf(row[0]).String()) - t.Fail() - } - if reflect.TypeOf(row[1]).String() != "*getters.RandomString" { - fmt.Printf("Expected '*getters.RandomString' for field [1], got %q\n", reflect.TypeOf(row[1]).String()) - t.Fail() - } - if reflect.TypeOf(row[2]).String() != "*getters.RandomDate" { - fmt.Printf("Expected '*getters.RandomDate' for field [2], got %q\n", reflect.TypeOf(row[2]).String()) - t.Fail() - } - count++ - } - } - }() - - generateInsertData(wantRows, values, rowsChan) - - wg.Wait() - tu.Assert(t, count == 3, "Invalid number of rows") -} - -func TestGenerateInsertStmt(t *testing.T) { - var table *tableparser.Table - tu.LoadJson(t, "sakila.film.json", &table) - want := "INSERT IGNORE INTO `sakila`.`film` " + - "(`title`,`description`,`release_year`,`language_id`," + - "`original_language_id`,`rental_duration`,`rental_rate`," + - "`length`,`replacement_cost`,`rating`,`special_features`," + - "`last_update`) VALUES " - - query := generateInsertStmt(table) - tu.Equals(t, want, query) -} diff --git a/tableparser/tableparser.go b/tableparser/tableparser.go index 0363b6c..53f371c 100644 --- a/tableparser/tableparser.go +++ b/tableparser/tableparser.go @@ -70,6 +70,7 @@ type IndexField struct { NonUnique bool Visible string // MySQL 8.0+ Expression sql.NullString // MySQL 8.0.16+ + Clustered string // TiDB Support } // Constraint holds Foreign Keys information @@ -83,7 +84,7 @@ type Constraint struct { // Field holds raw field information as defined in INFORMATION_SCHEMA type Field struct { - TableCatalog string + TableCatalog string `db:"TABLE_CATALOG"` TableSchema string TableName string ColumnName string @@ -124,7 +125,7 @@ type Trigger struct { DatabaseCollation string } -func NewTable(db *sql.DB, schema, tableName string) (*Table, error) { +func New(db *sql.DB, schema, tableName string) (*Table, error) { table := &Table{ Schema: url.QueryEscape(schema), Name: url.QueryEscape(tableName), @@ -136,7 +137,9 @@ func NewTable(db *sql.DB, schema, tableName string) (*Table, error) { if err != nil { return nil, err } + table.Constraints, err = getConstraints(db, table.Schema, table.Name) + if err != nil { return nil, err } @@ -159,8 +162,32 @@ func (t *Table) parse() error { // | | +---------- extra info (unsigned, etc) // | | | re := regexp.MustCompile(`^(.*?)(?:\((.*?)\)(.*))?$`) + fields := []string{ + "TABLE_CATALOG", + "TABLE_SCHEMA", + "TABLE_NAME", + "COLUMN_NAME", + "ORDINAL_POSITION", + "COLUMN_DEFAULT", + "IS_NULLABLE", + "DATA_TYPE", + "CHARACTER_MAXIMUM_LENGTH", + "CHARACTER_OCTET_LENGTH", + "NUMERIC_PRECISION", + "NUMERIC_SCALE", + "CHARACTER_SET_NAME", + "COLLATION_NAME", + "COLUMN_TYPE", + "COLUMN_KEY", + "EXTRA", + "PRIVILEGES", + "COLUMN_COMMENT", + } - query := "SELECT * FROM `information_schema`.`COLUMNS` WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION" + query := "SELECT " + strings.Join(fields, ",") + + " FROM `information_schema`.`COLUMNS` " + + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? " + + "ORDER BY ORDINAL_POSITION" constraints := constraintsAsMap(t.Constraints) @@ -226,19 +253,13 @@ func makeScanRecipients(f *Field, allowNull *string, cols []string) []interface{ &f.CharacterOctetLength, &f.NumericPrecision, &f.NumericScale, - } - - if len(cols) > 19 { // MySQL 5.5 does not have "DATETIME_PRECISION" field - fields = append(fields, &f.DatetimePrecision) - } - - fields = append(fields, &f.CharacterSetName, &f.CollationName, &f.ColumnType, &f.ColumnKey, &f.Extra, &f.Privileges, &f.ColumnComment) - - if len(cols) > 20 && cols[20] == "GENERATION_EXPRESSION" { // MySQL 5.7+ "GENERATION_EXPRESSION" field - fields = append(fields, &f.GenerationExpression) - } - if len(cols) > 21 && cols[21] == "SRS_ID" { // MySQL 8.0+ "SRS ID" field - fields = append(fields, &f.SrsID) + &f.CharacterSetName, + &f.CollationName, + &f.ColumnType, + &f.ColumnKey, + &f.Extra, + &f.Privileges, + &f.ColumnComment, } return fields @@ -277,6 +298,10 @@ func getIndexes(db *sql.DB, schema, tableName string) (map[string]Index, error) if err == nil && len(cols) >= 15 && cols[14] == "Expression" { fields = append(fields, &i.Expression) } + // support for TiDB (Clustered Index) + if err == nil && len(cols) >= 16 && cols[15] == "Clustered" { + fields = append(fields, &i.Clustered) + } err = rows.Scan(fields...) if err != nil { diff --git a/tableparser/tableparser_test.go b/tableparser/tableparser_test.go index 1f9572d..3a4abf8 100644 --- a/tableparser/tableparser_test.go +++ b/tableparser/tableparser_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - tu "github.com/Percona-Lab/mysql_random_data_load/testutils" + "github.com/Percona-Lab/mysql_random_data_load/internal/tu" _ "github.com/go-sql-driver/mysql" version "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" @@ -27,7 +27,7 @@ func TestParse(t *testing.T) { t.Fatalf("Unknown MySQL version %s", v.String()) } - table, err := NewTable(db, "sakila", "film") + table, err := New(db, "sakila", "film") if err != nil { t.Error(err) } diff --git a/tableparser/testdata/indexes.json b/tableparser/testdata/indexes.json index 6a03d3a..3f4703b 100644 --- a/tableparser/testdata/indexes.json +++ b/tableparser/testdata/indexes.json @@ -5,7 +5,8 @@ "actor_id" ], "Unique": true, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_fk_film_id": { "Name": "idx_fk_film_id", @@ -13,6 +14,7 @@ "film_id" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" } } \ No newline at end of file diff --git a/tableparser/testdata/table001.json b/tableparser/testdata/table001.json index a0fe07a..7ef56ab 100644 --- a/tableparser/testdata/table001.json +++ b/tableparser/testdata/table001.json @@ -691,7 +691,7 @@ }, "DatetimePrecision": { "Int64": 0, - "Valid": true + "Valid": false }, "CharacterSetName": { "String": "", @@ -722,7 +722,8 @@ "film_id" ], "Unique": true, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_fk_language_id": { "Name": "idx_fk_language_id", @@ -730,7 +731,8 @@ "language_id" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_fk_original_language_id": { "Name": "idx_fk_original_language_id", @@ -738,7 +740,8 @@ "original_language_id" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_title": { "Name": "idx_title", @@ -746,7 +749,8 @@ "title" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" } }, "Constraints": [ diff --git a/tableparser/testdata/table002.json b/tableparser/testdata/table002.json index a0fe07a..7ef56ab 100755 --- a/tableparser/testdata/table002.json +++ b/tableparser/testdata/table002.json @@ -691,7 +691,7 @@ }, "DatetimePrecision": { "Int64": 0, - "Valid": true + "Valid": false }, "CharacterSetName": { "String": "", @@ -722,7 +722,8 @@ "film_id" ], "Unique": true, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_fk_language_id": { "Name": "idx_fk_language_id", @@ -730,7 +731,8 @@ "language_id" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_fk_original_language_id": { "Name": "idx_fk_original_language_id", @@ -738,7 +740,8 @@ "original_language_id" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" }, "idx_title": { "Name": "idx_title", @@ -746,7 +749,8 @@ "title" ], "Unique": false, - "Visible": true + "Visible": true, + "Expression": "" } }, "Constraints": [ diff --git a/testdata/schema/0.user.sql b/testdata/schema/0.user.sql new file mode 100644 index 0000000..deabba6 --- /dev/null +++ b/testdata/schema/0.user.sql @@ -0,0 +1,2 @@ +GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' IDENTIFIED BY "root"; +GRANT ALL PRIVILEGES ON *.* TO 'msandbox'@'%' IDENTIFIED BY "msandbox";