139 lines
2.7 KiB
Go
139 lines
2.7 KiB
Go
package db
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net/url"
|
|
"os"
|
|
|
|
"github.com/joho/godotenv"
|
|
_ "github.com/lib/pq"
|
|
"xorm.io/xorm"
|
|
)
|
|
|
|
// UserTokens is one row of the user_tokens table: the refresh token stored
// for a single user, keyed by the user's ID (the primary key).
type UserTokens struct {
	UserID       string `xorm:"pk not null 'user_id'"`
	RefreshToken string `xorm:"not null 'refresh_token'"`
}

// TableName returns the name of the table UserTokens maps to; xorm consults
// this method instead of deriving a name from the struct type.
func (UserTokens) TableName() string {
	const table = "user_tokens"
	return table
}
|
|
|
|
var (
|
|
DB_PORT string
|
|
DB_HOST string
|
|
DB_USER string
|
|
DB_PASSWORD string
|
|
DB_NAME string
|
|
engine *xorm.Engine
|
|
)
|
|
|
|
func init() {
|
|
var err error = godotenv.Load()
|
|
if err != nil {
|
|
log.Fatal("Error loading .env file")
|
|
}
|
|
|
|
DB_PORT = url.PathEscape(os.Getenv("DB_PORT"))
|
|
DB_HOST = url.PathEscape(os.Getenv("DB_HOST"))
|
|
DB_USER = url.PathEscape(os.Getenv("DB_USER"))
|
|
DB_PASSWORD = url.PathEscape(os.Getenv("DB_PASSWORD"))
|
|
DB_NAME = url.PathEscape(os.Getenv("DB_NAME"))
|
|
|
|
engine, err = xorm.NewEngine(
|
|
"postgres",
|
|
fmt.Sprintf(
|
|
"host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
|
DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME,
|
|
),
|
|
)
|
|
|
|
if err != nil {
|
|
log.Fatal("Error loading database: ", err)
|
|
}
|
|
|
|
CheckEngine()
|
|
SyncTables()
|
|
|
|
fmt.Println("DB initialized")
|
|
}
|
|
|
|
func CheckEngine() {
|
|
err := engine.Ping()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func SyncTables() {
|
|
doesExist, err := engine.IsTableExist("user_tokens")
|
|
if err != nil {
|
|
log.Fatal("Error checking if table exists")
|
|
}
|
|
if !doesExist {
|
|
fmt.Println("Table doesn't exist, creating")
|
|
err = engine.Sync(new(UserTokens))
|
|
if err != nil {
|
|
log.Fatal("Error creating tables")
|
|
}
|
|
}
|
|
}
|
|
|
|
func GetRefreshToken(userID string) (refreshToken string, err error) {
|
|
u := new(UserTokens)
|
|
u.UserID = userID
|
|
has, err := engine.Get(u)
|
|
if err != nil {
|
|
log.Fatal("Error checking if user exists: ", err)
|
|
}
|
|
if has {
|
|
return u.RefreshToken, nil
|
|
} else {
|
|
return "", fmt.Errorf("user does not exist")
|
|
}
|
|
}
|
|
|
|
func Insert(userID string, refreshToken string) (err error) {
|
|
u := new(UserTokens)
|
|
u.UserID = userID
|
|
u.RefreshToken = refreshToken
|
|
_, err = engine.Insert(u)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func Update(userID string, refreshToken string) (err error) {
|
|
u := new(UserTokens)
|
|
u.UserID = userID
|
|
u.RefreshToken = refreshToken
|
|
_, err = engine.ID(userID).Update(u)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func SetRefreshToken(userID string, refreshToken string) (err error) {
|
|
var userTokens = new(UserTokens)
|
|
userTokens.UserID = userID
|
|
if a, _ := GetRefreshToken(userID); a != "" {
|
|
Update(userID, refreshToken)
|
|
} else {
|
|
Insert(userID, refreshToken)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func HasRefreshToken(userId string) (has bool) {
|
|
u := new(UserTokens)
|
|
u.UserID = userId
|
|
has, err := engine.Get(u)
|
|
if err != nil {
|
|
log.Fatal("Error checking if user exists: ", err)
|
|
}
|
|
return has
|
|
}
|