Skip to content

Commit

Permalink
entimport: Add field exclusion and error readability improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc committed Jul 19, 2022
1 parent 1d0105e commit 45c1f05
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 11 deletions.
7 changes: 6 additions & 1 deletion cmd/entimport/entimport.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ import (
"ariga.io/entimport/internal/mux"
)

var tablesFlag tables
var (
tablesFlag tables
excludeTablesFlag tables
)

func init() {
flag.Var(&tablesFlag, "tables", "comma-separated list of tables to inspect (all if empty)")
flag.Var(&excludeTablesFlag, "exclude-tables", "comma-separated list of tables to exclude")
}

func main() {
Expand All @@ -37,6 +41,7 @@ func main() {
}
i, err := entimport.NewImport(
entimport.WithTables(tablesFlag),
entimport.WithExcludedTables(excludeTablesFlag),
entimport.WithDriver(drv),
)
if err != nil {
Expand Down
23 changes: 17 additions & 6 deletions internal/entimport/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ type (

// ImportOptions are the options passed on to every SchemaImporter.
ImportOptions struct {
tables []string
schemaPath string
driver *mux.ImportDriver
tables []string
excludedTables []string
schemaPath string
driver *mux.ImportDriver
}

// ImportOption allows for managing import configuration using functional options.
Expand All @@ -71,6 +72,13 @@ func WithTables(tables []string) ImportOption {
}
}

// WithExcludedTables supplies the set of tables to exclude.
func WithExcludedTables(tables []string) ImportOption {
return func(i *ImportOptions) {
i.excludedTables = tables
}
}

// WithDriver provides an import driver to be used by SchemaImporter.
func WithDriver(drv *mux.ImportDriver) ImportOption {
return func(i *ImportOptions) {
Expand Down Expand Up @@ -234,8 +242,11 @@ func tableName(typeName string) string {

// resolvePrimaryKey returns the primary key as an ent field for a given table.
func resolvePrimaryKey(field fieldFunc, table *schema.Table) (f ent.Field, err error) {
if table.PrimaryKey == nil || len(table.PrimaryKey.Parts) != 1 {
return nil, fmt.Errorf("entimport: invalid primary key - single part key must be present")
if table.PrimaryKey == nil {
return nil, fmt.Errorf("entimport: missing primary key (table: %v)", table.Name)
}
if len(table.PrimaryKey.Parts) != 1 {
return nil, fmt.Errorf("entimport: invalid primary key - single part key must be present (table: %v, got: %v parts)", table.Name, len(table.PrimaryKey.Parts))
}
if f, err = field(table.PrimaryKey.Parts[0].C); err != nil {
return nil, err
Expand Down Expand Up @@ -324,7 +335,7 @@ func schemaMutations(field fieldFunc, tables []*schema.Table) ([]schemast.Mutato
}
node, err := upsertNode(field, table)
if err != nil {
return nil, err
return nil, fmt.Errorf("issue with table %v: %w", table.Name, err)
}
mutations[table.Name] = node
}
Expand Down
18 changes: 16 additions & 2 deletions internal/entimport/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,21 @@ func (m *MySQL) SchemaMutations(ctx context.Context) ([]schemast.Mutator, error)
if err != nil {
return nil, err
}
return schemaMutations(m.field, s.Tables)
var tables []*schema.Table
if m.excludedTables != nil {
excludedTableNames := make(map[string]bool)
for _, t := range m.excludedTables {
excludedTableNames[t] = true
}
// filter out tables that are in excludedTables:
for _, t := range s.Tables {
if !excludedTableNames[t.Name] {
tables = append(tables, t)
} else {
}
}
}
return schemaMutations(m.field, tables)
}

func (m *MySQL) field(column *schema.Column) (f ent.Field, err error) {
Expand All @@ -67,7 +81,7 @@ func (m *MySQL) field(column *schema.Column) (f ent.Field, err error) {
case *schema.TimeType:
f = field.Time(name)
default:
return nil, fmt.Errorf("entimport: unsupported type %q", typ)
return nil, fmt.Errorf("column %v: unsupported type %q", column.Name, typ)
}
applyColumnAttributes(f, column)
return f, err
Expand Down
18 changes: 16 additions & 2 deletions internal/entimport/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,21 @@ func (p *Postgres) SchemaMutations(ctx context.Context) ([]schemast.Mutator, err
if err != nil {
return nil, err
}
return schemaMutations(p.field, s.Tables)
var tables []*schema.Table
if p.excludedTables != nil {
excludedTableNames := make(map[string]bool)
for _, t := range p.excludedTables {
excludedTableNames[t] = true
}
// filter out tables that are in excludedTables:
for _, t := range s.Tables {
if !excludedTableNames[t.Name] {
tables = append(tables, t)
} else {
}
}
}
return schemaMutations(p.field, tables)
}

func (p *Postgres) field(column *schema.Column) (f ent.Field, err error) {
Expand Down Expand Up @@ -66,7 +80,7 @@ func (p *Postgres) field(column *schema.Column) (f ent.Field, err error) {
case *postgres.UUIDType:
f = field.UUID(name, uuid.New())
default:
return nil, fmt.Errorf("entimport: unsupported type %q", typ)
return nil, fmt.Errorf("column %v: unsupported type %q", column.Name, typ)
}
applyColumnAttributes(f, column)
return f, err
Expand Down

0 comments on commit 45c1f05

Please sign in to comment.