Skip to content

Commit

Permalink
fix: [#280] add wrap for sql
Browse files Browse the repository at this point in the history
  • Loading branch information
hwbrzzl committed Nov 9, 2024
1 parent 7674cb5 commit e4e70d3
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 164 deletions.
2 changes: 0 additions & 2 deletions contracts/database/schema/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ type Grammar interface {
CompileViews() string
// GetAttributeCommands Get the commands for the schema build.
GetAttributeCommands() []string
// GetModifiers Get the column modifiers.
GetModifiers() []func(Blueprint, ColumnDefinition) string
// TypeBigInteger Create the column definition for a big integer type.
TypeBigInteger(column ColumnDefinition) string
// TypeInteger Create the column definition for an integer type.
Expand Down
18 changes: 14 additions & 4 deletions database/schema/blueprint.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package schema

import (
"fmt"
"strings"

ormcontract "github.com/goravel/framework/contracts/database/orm"
Expand Down Expand Up @@ -85,7 +86,7 @@ func (r *Blueprint) GetPrefix() string {

func (r *Blueprint) GetTableName() string {
// TODO Add schema for Postgres
return r.prefix + r.table
return r.table
}

func (r *Blueprint) HasCommand(command string) bool {
Expand Down Expand Up @@ -209,11 +210,20 @@ func (r *Blueprint) addImpliedCommands(grammar schema.Grammar) {
}

func (r *Blueprint) createIndexName(ttype string, columns []string) string {
table := r.GetTableName()
index := strings.ToLower(table + "_" + strings.Join(columns, "_") + "_" + ttype)
var table string
if strings.Contains(r.table, ".") {
lastDotIndex := strings.LastIndex(r.table, ".")
table = r.table[:lastDotIndex+1] + r.prefix + r.table[lastDotIndex+1:]
} else {
table = r.prefix + r.table
}

index := strings.ToLower(fmt.Sprintf("%s_%s_%s", table, strings.Join(columns, "_"), ttype))

index = strings.ReplaceAll(index, "-", "_")
index = strings.ReplaceAll(index, ".", "_")

return strings.ReplaceAll(index, ".", "_")
return index
}

func (r *Blueprint) indexCommand(ttype string, columns []string, config ...schema.IndexConfig) *schema.Command {
Expand Down
11 changes: 5 additions & 6 deletions database/schema/blueprint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type BlueprintTestSuite struct {
func TestBlueprintTestSuite(t *testing.T) {
suite.Run(t, &BlueprintTestSuite{
grammars: map[database.Driver]schema.Grammar{
database.DriverPostgres: grammars.NewPostgres(),
database.DriverPostgres: grammars.NewPostgres("goravel_"),
},
})
}
Expand Down Expand Up @@ -139,6 +139,10 @@ func (s *BlueprintTestSuite) TestBuild() {
func (s *BlueprintTestSuite) TestCreateIndexName() {
name := s.blueprint.createIndexName("index", []string{"id", "name-1", "name.2"})
s.Equal("goravel_users_id_name_1_name_2_index", name)

s.blueprint.table = "public.users"
name = s.blueprint.createIndexName("index", []string{"id", "name-1", "name.2"})
s.Equal("public_goravel_users_id_name_1_name_2_index", name)
}

func (s *BlueprintTestSuite) TestGetAddedColumns() {
Expand All @@ -153,11 +157,6 @@ func (s *BlueprintTestSuite) TestGetAddedColumns() {
s.Equal(addedColumn, s.blueprint.GetAddedColumns()[0])
}

func (s *BlueprintTestSuite) TestGetTableName() {
s.blueprint.SetTable("users")
s.Equal("goravel_users", s.blueprint.GetTableName())
}

func (s *BlueprintTestSuite) TestHasCommand() {
s.False(s.blueprint.HasCommand(constants.CommandCreate))
s.blueprint.Create()
Expand Down
63 changes: 40 additions & 23 deletions database/schema/grammars/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ type Postgres struct {
attributeCommands []string
modifiers []func(schema.Blueprint, schema.ColumnDefinition) string
serials []string
wrap *Wrap
}

func NewPostgres() *Postgres {
func NewPostgres(tablePrefix string) *Postgres {
postgres := &Postgres{
attributeCommands: []string{constants.CommandComment},
serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"},
wrap: NewWrap(tablePrefix),
}
postgres.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{
postgres.ModifyDefault,
Expand All @@ -30,40 +32,40 @@ func NewPostgres() *Postgres {
}

func (r *Postgres) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s add column %s", blueprint.GetTableName(), getColumn(r, blueprint, command.Column))
return fmt.Sprintf("alter table %s add column %s", r.wrap.Table(blueprint.GetTableName()), r.getColumn(blueprint, command.Column))
}

func (r *Postgres) CompileCreate(blueprint schema.Blueprint) string {
return fmt.Sprintf("create table %s (%s)", blueprint.GetTableName(), strings.Join(getColumns(r, blueprint), ","))
return fmt.Sprintf("create table %s (%s)", r.wrap.Table(blueprint.GetTableName()), strings.Join(r.getColumns(blueprint), ", "))
}

func (r *Postgres) CompileDropAllDomains(domains []string) string {
return fmt.Sprintf("drop domain %s cascade", strings.Join(domains, ", "))
return fmt.Sprintf("drop domain %s cascade", strings.Join(r.EscapeNames(domains), ", "))
}

func (r *Postgres) CompileDropAllTables(tables []string) string {
return fmt.Sprintf("drop table %s cascade", strings.Join(tables, ", "))
return fmt.Sprintf("drop table %s cascade", strings.Join(r.EscapeNames(tables), ", "))
}

func (r *Postgres) CompileDropAllTypes(types []string) string {
return fmt.Sprintf("drop type %s cascade", strings.Join(types, ", "))
return fmt.Sprintf("drop type %s cascade", strings.Join(r.EscapeNames(types), ", "))
}

func (r *Postgres) CompileDropAllViews(views []string) string {
return fmt.Sprintf("drop view %s cascade", strings.Join(views, ", "))
return fmt.Sprintf("drop view %s cascade", strings.Join(r.EscapeNames(views), ", "))
}

func (r *Postgres) CompileDropIfExists(blueprint schema.Blueprint) string {
return fmt.Sprintf("drop table if exists %s", blueprint.GetTableName())
return fmt.Sprintf("drop table if exists %s", r.wrap.Table(blueprint.GetTableName()))
}

func (r *Postgres) CompileForeign(blueprint schema.Blueprint, command *schema.Command) string {
sql := fmt.Sprintf("alter table %s add constraint %s foreign key (%s) references %s (%s)",
blueprint.GetTableName(),
command.Index,
strings.Join(command.Columns, ", "),
fmt.Sprintf("%s%s", blueprint.GetPrefix(), command.On),
strings.Join(command.References, ", "))
r.wrap.Table(blueprint.GetTableName()),
r.wrap.Column(command.Index),
r.wrap.Columns(command.Columns),
r.wrap.Table(command.On),
r.wrap.Columns(command.References))
if command.OnDelete != "" {
sql += " on delete " + command.OnDelete
}
Expand All @@ -81,10 +83,10 @@ func (r *Postgres) CompileIndex(blueprint schema.Blueprint, command *schema.Comm
}

return fmt.Sprintf("create index %s on %s%s (%s)",
command.Index,
blueprint.GetTableName(),
r.wrap.Column(command.Index),
r.wrap.Table(blueprint.GetTableName()),
algorithm,
strings.Join(command.Columns, ", "),
r.wrap.Columns(command.Columns),
)
}

Expand All @@ -101,15 +103,15 @@ func (r *Postgres) CompileIndexes(schema, table string) string {
"left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+
"where tc.relname = %s and tn.nspname = %s "+
"group by ic.relname, am.amname, i.indisunique, i.indisprimary",
quoteString(table),
quoteString(schema),
r.wrap.Quote(table),
r.wrap.Quote(schema),

Check warning on line 107 in database/schema/grammars/postgres.go

View check run for this annotation

Codecov / codecov/patch

database/schema/grammars/postgres.go#L106-L107

Added lines #L106 - L107 were not covered by tests
)

return query
}

func (r *Postgres) CompilePrimary(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s add primary key (%s)", blueprint.GetTableName(), strings.Join(command.Columns, ","))
return fmt.Sprintf("alter table %s add primary key (%s)", r.wrap.Table(blueprint.GetTableName()), r.wrap.Columns(command.Columns))
}

func (r *Postgres) CompileTables() string {
Expand Down Expand Up @@ -155,10 +157,6 @@ func (r *Postgres) GetAttributeCommands() []string {
return r.attributeCommands
}

func (r *Postgres) GetModifiers() []func(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
return r.modifiers
}

func (r *Postgres) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetDefault() != nil {
return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault()))
Expand Down Expand Up @@ -207,3 +205,22 @@ func (r *Postgres) TypeString(column schema.ColumnDefinition) string {

return "varchar"
}

func (r *Postgres) getColumns(blueprint schema.Blueprint) []string {
var columns []string
for _, column := range blueprint.GetAddedColumns() {
columns = append(columns, r.getColumn(blueprint, column))
}

return columns
}

func (r *Postgres) getColumn(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
sql := fmt.Sprintf("%s %s", r.wrap.Column(column.GetName()), getType(r, column))

for _, modifier := range r.modifiers {
sql += modifier(blueprint, column)
}

return sql
}
88 changes: 77 additions & 11 deletions database/schema/grammars/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestPostgresSuite(t *testing.T) {
}

func (s *PostgresSuite) SetupTest() {
s.grammar = NewPostgres()
s.grammar = NewPostgres("goravel_")
}

func (s *PostgresSuite) TestCompileAdd() {
Expand All @@ -38,7 +38,7 @@ func (s *PostgresSuite) TestCompileAdd() {
Column: mockColumn,
})

s.Equal("alter table users add column name varchar(1) default 'goravel' not null", sql)
s.Equal(`alter table "goravel_users" add column "name" varchar(1) default 'goravel' not null`, sql)
}

func (s *PostgresSuite) TestCompileCreate() {
Expand Down Expand Up @@ -81,23 +81,38 @@ func (s *PostgresSuite) TestCompileCreate() {
// postgres.go::ModifyNullable
mockColumn2.EXPECT().GetNullable().Return(true).Once()

s.Equal("create table users (id serial primary key not null,name varchar(100) null)",
s.Equal(`create table "goravel_users" ("id" serial primary key not null, "name" varchar(100) null)`,
s.grammar.CompileCreate(mockBlueprint))
}

func (s *PostgresSuite) TestCompileDropAllDomains() {
s.Equal(`drop domain "domain", "user"."email" cascade`, s.grammar.CompileDropAllDomains([]string{"domain", "user.email"}))
}

func (s *PostgresSuite) TestCompileDropAllTables() {
s.Equal(`drop table "domain", "user"."email" cascade`, s.grammar.CompileDropAllTables([]string{"domain", "user.email"}))
}

func (s *PostgresSuite) TestCompileDropAllTypes() {
s.Equal(`drop type "domain", "user"."email" cascade`, s.grammar.CompileDropAllTypes([]string{"domain", "user.email"}))
}

func (s *PostgresSuite) TestCompileDropAllViews() {
s.Equal(`drop view "domain", "user"."email" cascade`, s.grammar.CompileDropAllViews([]string{"domain", "user.email"}))
}

func (s *PostgresSuite) TestCompileDropIfExists() {
mockBlueprint := mocksschema.NewBlueprint(s.T())
mockBlueprint.EXPECT().GetTableName().Return("users").Once()

s.Equal("drop table if exists users", s.grammar.CompileDropIfExists(mockBlueprint))
s.Equal(`drop table if exists "goravel_users"`, s.grammar.CompileDropIfExists(mockBlueprint))
}

func (s *PostgresSuite) TestCompileForeign() {
var mockBlueprint *mocksschema.Blueprint

beforeEach := func() {
mockBlueprint = mocksschema.NewBlueprint(s.T())
mockBlueprint.EXPECT().GetPrefix().Return("goravel_").Once()
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
}

Expand All @@ -110,23 +125,23 @@ func (s *PostgresSuite) TestCompileForeign() {
name: "with on delete and on update",
command: &contractsschema.Command{
Index: "fk_users_role_id",
Columns: []string{"role_id"},
Columns: []string{"role_id", "user_id"},
On: "roles",
References: []string{"id"},
References: []string{"id", "user_id"},
OnDelete: "cascade",
OnUpdate: "restrict",
},
expectSql: "alter table users add constraint fk_users_role_id foreign key (role_id) references goravel_roles (id) on delete cascade on update restrict",
expectSql: `alter table "goravel_users" add constraint "fk_users_role_id" foreign key ("role_id", "user_id") references "goravel_roles" ("id", "user_id") on delete cascade on update restrict`,
},
{
name: "without on delete and on update",
command: &contractsschema.Command{
Index: "fk_users_role_id",
Columns: []string{"role_id"},
Columns: []string{"role_id", "user_id"},
On: "roles",
References: []string{"id"},
References: []string{"id", "user_id"},
},
expectSql: "alter table users add constraint fk_users_role_id foreign key (role_id) references goravel_roles (id)",
expectSql: `alter table "goravel_users" add constraint "fk_users_role_id" foreign key ("role_id", "user_id") references "goravel_roles" ("id", "user_id")`,
},
}

Expand All @@ -140,6 +155,57 @@ func (s *PostgresSuite) TestCompileForeign() {
}
}

func (s *PostgresSuite) TestCompileIndex() {
var mockBlueprint *mocksschema.Blueprint

beforeEach := func() {
mockBlueprint = mocksschema.NewBlueprint(s.T())
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
}

tests := []struct {
name string
command *contractsschema.Command
expectSql string
}{
{
name: "with Algorithm",
command: &contractsschema.Command{
Index: "fk_users_role_id",
Columns: []string{"role_id", "user_id"},
Algorithm: "btree",
},
expectSql: `create index "fk_users_role_id" on "goravel_users" using btree ("role_id", "user_id")`,
},
{
name: "without Algorithm",
command: &contractsschema.Command{
Index: "fk_users_role_id",
Columns: []string{"role_id", "user_id"},
},
expectSql: `create index "fk_users_role_id" on "goravel_users" ("role_id", "user_id")`,
},
}

for _, test := range tests {
s.Run(test.name, func() {
beforeEach()

sql := s.grammar.CompileIndex(mockBlueprint, test.command)
s.Equal(test.expectSql, sql)
})
}
}

func (s *PostgresSuite) TestCompilePrimary() {
mockBlueprint := mocksschema.NewBlueprint(s.T())
mockBlueprint.EXPECT().GetTableName().Return("users").Once()

s.Equal(`alter table "goravel_users" add primary key ("role_id", "user_id")`, s.grammar.CompilePrimary(mockBlueprint, &contractsschema.Command{
Columns: []string{"role_id", "user_id"},
}))
}

func (s *PostgresSuite) TestEscapeNames() {
// SingleName
names := []string{"username"}
Expand Down
Loading

0 comments on commit e4e70d3

Please sign in to comment.