diff --git a/internal/cmd/database/dump.go b/internal/cmd/database/dump.go index 59be8d00..d6aa5a1f 100644 --- a/internal/cmd/database/dump.go +++ b/internal/cmd/database/dump.go @@ -34,6 +34,7 @@ type dumpFlags struct { rdonly bool tables string wheres string + columns []string output string threads int schemaOnly bool @@ -70,6 +71,8 @@ func DumpCmd(ch *cmdutil.Helper) *cobra.Command { cmd.PersistentFlags().BoolVar(&f.schemaOnly, "schema-only", false, "Only dump schema, skip table data.") cmd.PersistentFlags().StringVar(&f.outputFormat, "output-format", "sql", "Output format for data: sql (for MySQL, default), json, or csv.") + cmd.PersistentFlags().StringArrayVar(&f.columns, "columns", nil, + "Columns to include for specific tables (format: 'table:col1,col2'). Can be specified multiple times for different tables.") return cmd } @@ -285,6 +288,14 @@ func dump(ch *cmdutil.Helper, cmd *cobra.Command, flags *dumpFlags, args []strin } } + if len(flags.columns) > 0 { + includes, err := parseColumnIncludes(flags.columns) + if err != nil { + return fmt.Errorf("invalid --columns: %w", err) + } + cfg.ColumnIncludes = includes + } + d, err := dumper.NewDumper(cfg) if err != nil { return err @@ -346,3 +357,38 @@ func getDatabaseName(name, addr string) (string, error) { return "", fmt.Errorf("could not find a valid database name for this database: %v", dbs) } + +// parseColumnIncludes parses --columns flags into a map of table name -> column names to include. 
// parseColumnIncludes parses the values of the repeatable --columns flag into
// a map of table name -> set of column names to include.
//
// Each spec must have the form "table:col1,col2". Whitespace around the table
// name and around each column name is trimmed, and empty column entries (for
// example the middle one in "t:a,,b") are ignored. Specs that name the same
// table are merged into a single set.
//
// An error is returned, together with a nil map, when a spec has no ':'
// separator, has an empty table name, or does not name at least one column.
// An empty input yields an empty, non-nil map.
func parseColumnIncludes(columns []string) (map[string]map[string]bool, error) {
	result := make(map[string]map[string]bool, len(columns))

	for _, colSpec := range columns {
		table, colList, found := strings.Cut(colSpec, ":")
		if !found {
			return nil, fmt.Errorf("invalid column spec %q: expected 'table:col1,col2' format", colSpec)
		}
		table = strings.TrimSpace(table)
		if table == "" {
			return nil, fmt.Errorf("invalid column spec %q: table name cannot be empty", colSpec)
		}

		// Gather this spec's columns before touching result so the
		// "at least one column" rule is enforced per spec. Checking the
		// merged per-table set instead would let an empty spec such as
		// "users:" slip through whenever an earlier spec had already
		// populated that table.
		var cols []string
		for _, col := range strings.Split(colList, ",") {
			if col = strings.TrimSpace(col); col != "" {
				cols = append(cols, col)
			}
		}
		if len(cols) == 0 {
			return nil, fmt.Errorf("invalid column spec %q: at least one column must be specified", colSpec)
		}

		set := result[table]
		if set == nil {
			set = make(map[string]bool, len(cols))
			result[table] = set
		}
		for _, col := range cols {
			set[col] = true
		}
	}

	return result, nil
}
"}, + want: map[string]map[string]bool{ + "users": {"id": true, "name": true, "email": true}, + }, + }, + { + name: "same table specified multiple times merges columns", + columns: []string{"users:id", "users:name,email"}, + want: map[string]map[string]bool{ + "users": {"id": true, "name": true, "email": true}, + }, + }, + { + name: "empty input", + columns: []string{}, + want: map[string]map[string]bool{}, + }, + { + name: "missing colon", + columns: []string{"users-id,name"}, + wantErr: true, + errContains: "expected 'table:col1,col2' format", + }, + { + name: "empty table name", + columns: []string{":id,name"}, + wantErr: true, + errContains: "table name cannot be empty", + }, + { + name: "empty column list", + columns: []string{"users:"}, + wantErr: true, + errContains: "at least one column must be specified", + }, + { + name: "only whitespace columns", + columns: []string{"users: , , "}, + wantErr: true, + errContains: "at least one column must be specified", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseColumnIncludes(tt.columns) + + if tt.wantErr { + c.Assert(err, qt.IsNotNil) + if tt.errContains != "" { + c.Assert(err.Error(), qt.Contains, tt.errContains) + } + return + } + + c.Assert(err, qt.IsNil) + c.Assert(got, qt.DeepEquals, tt.want) + }) + } +} diff --git a/internal/dumper/dumper.go b/internal/dumper/dumper.go index 4ee52fa4..ea9cb145 100644 --- a/internal/dumper/dumper.go +++ b/internal/dumper/dumper.go @@ -55,6 +55,7 @@ type Config struct { Wheres map[string]string Selects map[string]map[string]string Filters map[string]map[string]string + ColumnIncludes map[string]map[string]bool // Interval in millisecond. 
IntervalMs int @@ -366,9 +367,18 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext, ctx.fieldNames = make([]string, 0) ctx.selfields = make([]string, 0) + // Check if we have column inclusion filters for this table + includeFilter := d.cfg.ColumnIncludes[table] + hasIncludeFilter := len(includeFilter) > 0 + for _, name := range flds { d.log.Debug("dump", zap.Any("filters", d.cfg.Filters), zap.String("table", table), zap.String("field_name", name)) + // If include filter is specified, only include listed columns + if hasIncludeFilter && !includeFilter[name] { + continue + } + if _, ok := d.cfg.Filters[table][name]; ok { continue } @@ -382,6 +392,11 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext, } } + // Validate that at least one column was included when using include filter + if hasIncludeFilter && len(ctx.fieldNames) == 0 { + return nil, fmt.Errorf("no valid columns found for table %q with column filter (available columns: %v)", table, flds) + } + if v, ok := d.cfg.Wheres[table]; ok { ctx.where = fmt.Sprintf(" WHERE %v", v) } diff --git a/internal/dumper/dumper_test.go b/internal/dumper/dumper_test.go index e9232117..5ae7d2f3 100644 --- a/internal/dumper/dumper_test.go +++ b/internal/dumper/dumper_test.go @@ -1574,6 +1574,131 @@ func TestEscapeBytes(t *testing.T) { } } +func TestDumperColumnIncludes(t *testing.T) { + c := qt.New(t) + + log := xlog.NewStdLog(xlog.Level(xlog.INFO)) + fakedbs := driver.NewTestHandler(log) + server, err := driver.MockMysqlServer(log, fakedbs) + c.Assert(err, qt.IsNil) + c.Cleanup(func() { server.Close() }) + + address := server.Addr() + + // Result with only id and name columns (filtered from full set) + selectResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "id", Type: querypb.Type_INT32}, + {Name: "name", Type: querypb.Type_VARCHAR}, + }, + Rows: make([][]sqltypes.Value, 0, 256), + } + + for i := 0; i < 100; i++ { + row := []sqltypes.Value{ + 
sqltypes.MakeTrusted(querypb.Type_INT32, []byte("42")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("testuser")), + } + selectResult.Rows = append(selectResult.Rows, row) + } + + schemaResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "Table", Type: querypb.Type_VARCHAR}, + {Name: "Create Table", Type: querypb.Type_VARCHAR}, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("users")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, + []byte("CREATE TABLE `users` (`id` int, `name` varchar(255), `email` varchar(255), `password` varchar(255)) ENGINE=InnoDB")), + }, + }, + } + + tablesResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "Tables_in_test", Type: querypb.Type_VARCHAR}, + }, + Rows: [][]sqltypes.Value{ + {sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("users"))}, + }, + } + + viewsResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "TABLE_NAME", Type: querypb.Type_VARCHAR}, + }, + Rows: [][]sqltypes.Value{}, + } + + // Fields result includes all columns, but only id and name should be dumped + fieldsResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "Field", Type: querypb.Type_VARCHAR}, + {Name: "Type", Type: querypb.Type_VARCHAR}, + {Name: "Null", Type: querypb.Type_VARCHAR}, + {Name: "Key", Type: querypb.Type_VARCHAR}, + {Name: "Default", Type: querypb.Type_VARCHAR}, + {Name: "Extra", Type: querypb.Type_VARCHAR}, + }, + Rows: [][]sqltypes.Value{ + testRow("id", ""), + testRow("name", ""), + testRow("email", ""), + testRow("password", ""), + }, + } + + // fakedbs setup + { + fakedbs.AddQueryPattern("use .*", &sqltypes.Result{}) + fakedbs.AddQueryPattern("show create table .*", schemaResult) + fakedbs.AddQueryPattern("show tables from .*", tablesResult) + fakedbs.AddQueryPattern("select table_name \n\t\t\t from information_schema.tables \n\t\t\t where table_schema like 'test' \n\t\t\t and table_type = 'view'\n\t\t\t", viewsResult) + 
fakedbs.AddQueryPattern("show fields from .*", fieldsResult) + // The SELECT should only include id and name columns + fakedbs.AddQueryPattern("select `id`, `name` from `test`\\.`users` .*", selectResult) + fakedbs.AddQueryPattern("set .*", &sqltypes.Result{}) + } + + cfg := &Config{ + Database: "test", + Table: "users", + Outdir: c.TempDir(), + User: "mock", + Password: "mock", + Address: address, + ChunksizeInMB: 1, + Threads: 16, + StmtSize: 10000, + IntervalMs: 500, + // Only include id and name columns for the users table + ColumnIncludes: map[string]map[string]bool{ + "users": { + "id": true, + "name": true, + }, + }, + } + + d, err := NewDumper(cfg) + c.Assert(err, qt.IsNil) + + err = d.Run(context.Background()) + c.Assert(err, qt.IsNil) + + // Verify the output contains only the specified columns + dat, err := os.ReadFile(cfg.Outdir + "/test.users.00001.sql") + c.Assert(err, qt.IsNil) + + // Should have INSERT with only id and name + c.Assert(string(dat), qt.Contains, "INSERT INTO `users`(`id`,`name`)") + // Should NOT have email or password columns + c.Assert(strings.Contains(string(dat), "email"), qt.IsFalse) + c.Assert(strings.Contains(string(dat), "password"), qt.IsFalse) +} + func TestDumper_OutputFormats(t *testing.T) { const numTestRows = 100