package bun

import (
	"context"
	"database/sql"
	"sort"
	"strconv"

	"github.com/uptrace/bun/dialect/feature"
	"github.com/uptrace/bun/dialect/sqltype"
	"github.com/uptrace/bun/internal"
	"github.com/uptrace/bun/schema"
)

// CreateTableQuery builds and executes a CREATE TABLE statement.
//
// AppendQuery renders the column list from the resolved table's fields,
// followed by any ColumnExpr definitions, then PRIMARY KEY, UNIQUE and
// FOREIGN KEY constraints, and finally optional PARTITION BY and TABLESPACE
// clauses.
type CreateTableQuery struct {
	baseQuery

	// temp emits the TEMP keyword (set by Temp).
	temp        bool
	// ifNotExists emits IF NOT EXISTS when the dialect supports
	// feature.TableNotExists (set by IfNotExists).
	ifNotExists bool
	// varchar, when > 0, is the length written for VARCHAR columns whose type
	// was not explicitly overridden (set by Varchar).
	varchar     int

	// fks holds FOREIGN KEY constraint bodies, each rendered after
	// "FOREIGN KEY " (set by ForeignKey and WithForeignKeys).
	fks         []schema.QueryWithArgs
	// partitionBy, when non-zero, is rendered after " PARTITION BY ".
	partitionBy schema.QueryWithArgs
	// tablespace, when non-zero, is rendered after " TABLESPACE ".
	tablespace  schema.QueryWithArgs
}

// Compile-time assertion that *CreateTableQuery implements Query.
var _ Query = (*CreateTableQuery)(nil)

// NewCreateTableQuery returns a CREATE TABLE builder bound to db whose
// connection defaults to db.DB.
func NewCreateTableQuery(db *DB) *CreateTableQuery {
	base := baseQuery{
		db:   db,
		conn: db.DB,
	}
	return &CreateTableQuery{baseQuery: base}
}

// Conn hands db to setConn, presumably so the query runs on db instead of the
// connection chosen in NewCreateTableQuery. It returns q for chaining.
func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery {
	q.setConn(db)

	return q
}

// Model registers model via setTableModel. The resulting table presumably
// supplies the name, fields, PKs and unique groups that AppendQuery reads.
// It returns q for chaining.
func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
	q.setTableModel(model)

	return q
}

// ------------------------------------------------------------------------------

// Table adds each name in tables as an identifier built with
// schema.UnsafeIdent. AppendQuery renders only the first table
// (via appendFirstTable). It returns q for chaining.
func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
	for i := range tables {
		q.addTable(schema.UnsafeIdent(tables[i]))
	}
	return q
}

// TableExpr adds a raw table expression; query is formatted with args.
// It returns q for chaining.
func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery {
	expr := schema.SafeQuery(query, args)
	q.addTable(expr)
	return q
}

// ModelTableExpr stores query (formatted with args) as the model table name
// expression, replacing any previously stored one. It returns q for chaining.
func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery {
	expr := schema.SafeQuery(query, args)
	q.modelTableName = expr
	return q
}

// ColumnExpr adds a raw column definition. AppendQuery writes it after the
// definitions derived from the model's fields. It returns q for chaining.
func (q *CreateTableQuery) ColumnExpr(query string, args ...interface{}) *CreateTableQuery {
	expr := schema.SafeQuery(query, args)
	q.addColumn(expr)
	return q
}

// ------------------------------------------------------------------------------

// Temp makes AppendQuery emit "CREATE TEMP TABLE". It returns q for chaining.
func (q *CreateTableQuery) Temp() *CreateTableQuery {
	q.temp = true

	return q
}

// IfNotExists requests an "IF NOT EXISTS" clause. AppendQuery writes it only
// when the dialect reports feature.TableNotExists. It returns q for chaining.
func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
	q.ifNotExists = true

	return q
}

// Varchar sets the length written for VARCHAR columns whose type was not
// explicitly overridden ("varchar(n)"). A value of n <= 0 leaves such columns
// unchanged. It returns q for chaining.
func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
	q.varchar = n

	return q
}

// ForeignKey appends a table-level constraint. The formatted query is written
// after "FOREIGN KEY ", e.g. "(author_id) REFERENCES users (id)".
// It returns q for chaining.
func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery {
	fk := schema.SafeQuery(query, args)
	q.fks = append(q.fks, fk)
	return q
}

// PartitionBy sets the expression written after " PARTITION BY ", replacing
// any earlier one. It returns q for chaining.
func (q *CreateTableQuery) PartitionBy(query string, args ...interface{}) *CreateTableQuery {
	expr := schema.SafeQuery(query, args)
	q.partitionBy = expr
	return q
}

// TableSpace sets the identifier written after " TABLESPACE ", built with
// schema.UnsafeIdent. It replaces any earlier value and returns q for chaining.
func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
	ident := schema.UnsafeIdent(tablespace)
	q.tablespace = ident
	return q
}

// WithForeignKeys adds one FOREIGN KEY constraint per relation of the model,
// skipping has-many and many-to-many relations. Each constraint has the form
//
//	(base columns) REFERENCES join_table (join columns) <OnUpdate> <OnDelete>
//
// where relation.OnUpdate and relation.OnDelete are inserted verbatim.
//
// A model must be set first (see Model). If none is set, errNilModel is
// recorded as the query error, unless an earlier error is already stored.
// AppendQuery then reports it, instead of this method panicking on a nil
// model. It returns q for chaining.
func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
	if q.tableModel == nil {
		if q.err == nil {
			q.err = errNilModel
		}
		return q
	}

	for _, relation := range q.tableModel.Table().Relations {
		// NOTE(review): has-many and m2m constraints presumably live on the
		// other table / the join table rather than this one — confirm.
		if relation.Type == schema.ManyToManyRelation ||
			relation.Type == schema.HasManyRelation {
			continue
		}

		q.ForeignKey("(?) REFERENCES ? (?) ? ?",
			Safe(appendColumns(nil, "", relation.BaseFields)),
			relation.JoinTable.SQLName,
			Safe(appendColumns(nil, "", relation.JoinFields)),
			Safe(relation.OnUpdate),
			Safe(relation.OnDelete),
		)
	}
	return q
}

//------------------------------------------------------------------------------

// Operation returns the name of the SQL operation this query performs,
// "CREATE TABLE".
func (q *CreateTableQuery) Operation() string {
	const op = "CREATE TABLE"
	return op
}

// AppendQuery renders the CREATE TABLE statement, appending it to b.
//
// The generated statement has the shape
//
//	CREATE [TEMP] TABLE [IF NOT EXISTS] <table> (
//	    <field definitions>, <ColumnExpr definitions>
//	    [, PRIMARY KEY (...)] [, UNIQUE (...)]... [, FOREIGN KEY ...]...
//	) [PARTITION BY ...] [TABLESPACE ...]
//
// It returns, in order of precedence: the error already stored on the query,
// errNilModel when no table has been resolved, or the first error produced
// while formatting a sub-expression.
func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	switch {
	case q.err != nil:
		return nil, q.err
	case q.table == nil:
		return nil, errNilModel
	}

	// Statement head.
	b = append(b, "CREATE "...)
	if q.temp {
		b = append(b, "TEMP "...)
	}
	b = append(b, "TABLE "...)
	if q.ifNotExists && fmter.Dialect().Features().Has(feature.TableNotExists) {
		b = append(b, "IF NOT EXISTS "...)
	}
	if b, err = q.appendFirstTable(fmter, b); err != nil {
		return nil, err
	}
	b = append(b, " ("...)

	// One definition per model field: name, type and column modifiers.
	for idx, fld := range q.table.Fields {
		if idx != 0 {
			b = append(b, ", "...)
		}
		b = append(b, fld.SQLName...)
		b = append(b, ' ')
		b = q.appendSQLType(b, fld)

		if fld.NotNull {
			b = append(b, " NOT NULL"...)
		}
		if fld.AutoIncrement {
			// AUTO_INCREMENT wins; IDENTITY is the fallback; otherwise nothing.
			if fmter.Dialect().Features().Has(feature.AutoIncrement) {
				b = append(b, " AUTO_INCREMENT"...)
			} else if fmter.Dialect().Features().Has(feature.Identity) {
				b = append(b, " IDENTITY"...)
			}
		}
		if fld.Identity && fmter.Dialect().Features().Has(feature.GeneratedIdentity) {
			b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...)
		}
		if fld.SQLDefault != "" {
			b = append(b, " DEFAULT "...)
			b = append(b, fld.SQLDefault...)
		}
	}

	// Extra definitions from ColumnExpr. A separator is written only when a
	// definition already precedes it inside the parentheses, so a table built
	// purely from ColumnExpr does not start with a stray comma.
	wroteDef := len(q.table.Fields) > 0
	for _, col := range q.columns {
		if wroteDef {
			b = append(b, ", "...)
		}
		wroteDef = true
		if b, err = col.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}

	// Table-level constraints; each helper writes its own leading comma.
	b = q.appendPKConstraint(b, q.table.PKs)
	b = q.appendUniqueConstraints(fmter, b)
	if b, err = q.appendFKConstraints(fmter, b); err != nil {
		return nil, err
	}
	b = append(b, ')')

	// Optional trailing clauses.
	if !q.partitionBy.IsZero() {
		b = append(b, " PARTITION BY "...)
		if b, err = q.partitionBy.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	if !q.tablespace.IsZero() {
		b = append(b, " TABLESPACE "...)
		if b, err = q.tablespace.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}

// appendSQLType appends the column type used for field in CREATE TABLE.
//
// If field.CreateTableSQLType differs from field.DiscoveredSQLType (an
// explicit override), it is written verbatim. Otherwise, when Varchar(n) was
// set with n > 0 and the type is sqltype.VarChar, "varchar(n)" is written.
// In every other case the type is appended unchanged.
func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
	typ := field.CreateTableSQLType
	sized := typ == field.DiscoveredSQLType &&
		q.varchar > 0 &&
		typ == sqltype.VarChar
	if !sized {
		return append(b, typ...)
	}

	b = append(b, "varchar("...)
	b = strconv.AppendInt(b, int64(q.varchar), 10)
	return append(b, ')')
}

// appendUniqueConstraints appends UNIQUE constraints for every group in
// q.table.Unique, visiting group names in sorted order so output is
// deterministic.
//
// A named group becomes a single multi-column constraint. Fields in the group
// with the empty name each get their own single-column, unnamed constraint.
func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
	groups := q.table.Unique

	names := make([]string, 0, len(groups))
	for name := range groups {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		fields := groups[name]
		if name != "" {
			b = q.appendUniqueConstraint(fmter, b, name, fields...)
			continue
		}
		for _, f := range fields {
			b = q.appendUniqueConstraint(fmter, b, name, f)
		}
	}
	return b
}

// appendUniqueConstraint appends ", [CONSTRAINT <name>] UNIQUE (<columns>)".
// The CONSTRAINT clause is written only when name is non-empty, quoted with
// fmter.AppendIdent.
func (q *CreateTableQuery) appendUniqueConstraint(
	fmter schema.Formatter, b []byte, name string, fields ...*schema.Field,
) []byte {
	b = append(b, ',')
	if name != "" {
		b = append(b, " CONSTRAINT "...)
		b = fmter.AppendIdent(b, name)
	}

	b = append(b, " UNIQUE ("...)
	b = appendColumns(b, "", fields)
	return append(b, ')')
}

// appendFKConstraints appends ", FOREIGN KEY <fk>" for each constraint
// registered via ForeignKey / WithForeignKeys, in registration order. It stops
// at and returns the first formatting error.
func (q *CreateTableQuery) appendFKConstraints(
	fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
	for i := range q.fks {
		b = append(b, ", FOREIGN KEY "...)
		if b, err = q.fks[i].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}

// appendPKConstraint appends ", PRIMARY KEY (<columns>)" built from pks.
// It writes nothing when pks is empty.
func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte {
	if len(pks) > 0 {
		b = append(b, ", PRIMARY KEY ("...)
		b = appendColumns(b, "", pks)
		b = append(b, ')')
	}
	return b
}

// ------------------------------------------------------------------------------

// Exec renders the CREATE TABLE statement and executes it on the query's
// connection.
//
// When a table has been resolved, the table's ZeroIface is checked for
// BeforeCreateTableHook and AfterCreateTableHook, which run before and after
// execution. Without a table both hooks are skipped. Previously the before-hook
// dereferenced a nil table and panicked; now AppendQuery reports the problem
// (errNilModel) as an ordinary error.
//
// dest is not used by this method.
func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
	// Guard mirrors the after-hook below: the hook reads q.table.ZeroIface.
	if q.table != nil {
		if err := q.beforeCreateTableHook(ctx); err != nil {
			return nil, err
		}
	}

	queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
	if err != nil {
		return nil, err
	}

	query := internal.String(queryBytes)

	res, err := q.exec(ctx, q, query)
	if err != nil {
		return nil, err
	}

	if q.table != nil {
		if err := q.afterCreateTableHook(ctx); err != nil {
			return nil, err
		}
	}

	return res, nil
}

// beforeCreateTableHook calls BeforeCreateTable(ctx, q) if q.table.ZeroIface
// implements BeforeCreateTableHook, and returns the hook's error. It returns
// nil when the hook is not implemented. q.table must be non-nil.
func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error {
	hook, ok := q.table.ZeroIface.(BeforeCreateTableHook)
	if !ok {
		return nil
	}
	return hook.BeforeCreateTable(ctx, q)
}

// afterCreateTableHook calls AfterCreateTable(ctx, q) if q.table.ZeroIface
// implements AfterCreateTableHook, and returns the hook's error. It returns
// nil when the hook is not implemented. q.table must be non-nil.
func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
	hook, ok := q.table.ZeroIface.(AfterCreateTableHook)
	if !ok {
		return nil
	}
	return hook.AfterCreateTable(ctx, q)
}