From 12e32b58c1bc27a482556aa0884798ba8d2d391b Mon Sep 17 00:00:00 2001 From: Billy Olsen Date: Mon, 14 Aug 2023 18:48:08 -0700 Subject: [PATCH] Update the schema to rename groups to ldapgroups Update the schema to rename groups to ldapgroups. This allows for the table name and means to access it to be consistent across all databases. Signed-off-by: Billy Olsen --- sqlite.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/sqlite.go b/sqlite.go index 4fb17f9..ebf368d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -2,6 +2,7 @@ package main import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" @@ -50,9 +51,9 @@ CREATE TABLE IF NOT EXISTS users ( statement.Exec() statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_user_name on users(name)") statement.Exec() - statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS groups (id INTEGER PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") + statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS ldapgroups (id INTEGER PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") statement.Exec() - statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on groups(name)") + statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on ldapgroups(name)") statement.Exec() statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS includegroups (id INTEGER PRIMARY KEY, parentgroupid INTEGER NOT NULL, includegroupid INTEGER NOT NULL)") statement.Exec() @@ -66,6 +67,26 @@ func (b SqliteBackend) MigrateSchema(db *sql.DB, checker func(*sql.DB, string) b statement, _ := db.Prepare("ALTER TABLE users ADD COLUMN sshkeys TEXT DEFAULT ''") statement.Exec() } + + if TableExists(db, "groups") { + // Drop the table created during schema creation + statement, _ := db.Prepare("DROP TABLE ldapgroups") + statement.Exec() + + statement, _ = db.Prepare("ALTER TABLE groups RENAME TO ldapgroups") + statement.Exec() + } +} + +// Indicates whether the table exists or not +func TableExists(db *sql.DB, tableName string) bool { + var found string + err := db.QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s", tableName)).Scan( + &found) + if err != nil { + return false + } + return true } func main() {}