Sin descripción
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

invitation_dao.go 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package dao
  2. import (
  3. "context"
  4. "fmt"
  5. "math/rand"
  6. "time"
  7. "git.x2erp.com/qdy/go-svc-configure/internal/tables"
  8. "github.com/jmoiron/sqlx"
  9. )
  10. const (
  11. invitationCodeLength = 16
  12. invitationCodeChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  13. defaultExpiresDays = 7
  14. )
  15. // GenerateInvitationCode 生成唯一邀请码
  16. func GenerateInvitationCode() string {
  17. rand.Seed(time.Now().UnixNano())
  18. b := make([]byte, invitationCodeLength)
  19. for i := range b {
  20. b[i] = invitationCodeChars[rand.Intn(len(invitationCodeChars))]
  21. }
  22. return string(b)
  23. }
  24. // CreateInvitationCode 创建邀请码
  25. func CreateInvitationCode(ctx context.Context, tx *sqlx.Tx, tenantID, roleID, creator string, expiresDays int) (string, error) {
  26. if expiresDays <= 0 {
  27. expiresDays = defaultExpiresDays
  28. }
  29. // 生成唯一邀请码
  30. var code string
  31. var exists bool
  32. maxAttempts := 10
  33. for i := 0; i < maxAttempts; i++ {
  34. code = GenerateInvitationCode()
  35. // 检查是否已存在
  36. var count int
  37. checkQuery := `SELECT COUNT(*) FROM config_invitation_code WHERE code = ?`
  38. err := tx.GetContext(ctx, &count, checkQuery, code)
  39. if err != nil {
  40. return "", fmt.Errorf("检查邀请码存在性失败: %v", err)
  41. }
  42. if count == 0 {
  43. exists = false
  44. break
  45. }
  46. exists = true
  47. }
  48. if exists {
  49. return "", fmt.Errorf("生成唯一邀请码失败,请重试")
  50. }
  51. // 计算过期时间
  52. expiresAt := time.Now().Add(time.Duration(expiresDays) * 24 * time.Hour)
  53. id := fmt.Sprintf("invitation.%s.%s.%d", tenantID, roleID, time.Now().Unix())
  54. query := `
  55. INSERT INTO config_invitation_code
  56. (id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at)
  57. VALUES (?, ?, ?, ?, ?, 0, '', ?, CURRENT_TIMESTAMP)
  58. `
  59. result, err := tx.ExecContext(ctx, query, id, code, tenantID, roleID, expiresAt, creator)
  60. if err != nil {
  61. return "", fmt.Errorf("创建邀请码失败: %v", err)
  62. }
  63. rowsAffected, err := ValidateResultRowsAffected(result, err, 1)
  64. if err != nil {
  65. return "", err
  66. }
  67. if rowsAffected != 1 {
  68. return "", fmt.Errorf("创建邀请码失败,影响行数: %d", rowsAffected)
  69. }
  70. return code, nil
  71. }
  72. // ValidateInvitationCode 验证邀请码有效性
  73. func ValidateInvitationCode(ctx context.Context, db *sqlx.DB, code string) (*tables.InvitationCodeDB, error) {
  74. var invitation tables.InvitationCodeDB
  75. query := `
  76. SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
  77. FROM config_invitation_code
  78. WHERE code = ?
  79. `
  80. err := db.GetContext(ctx, &invitation, query, code)
  81. if err != nil {
  82. return nil, fmt.Errorf("邀请码不存在: %v", err)
  83. }
  84. // 检查是否已使用
  85. if invitation.Used != 0 {
  86. return nil, fmt.Errorf("邀请码已使用")
  87. }
  88. // 检查是否过期
  89. if time.Now().After(invitation.ExpiresAt) {
  90. return nil, fmt.Errorf("邀请码已过期")
  91. }
  92. return &invitation, nil
  93. }
  94. // MarkInvitationCodeUsed 标记邀请码已使用
  95. func MarkInvitationCodeUsed(ctx context.Context, tx *sqlx.Tx, code, usedBy string) (int64, error) {
  96. query := `
  97. UPDATE config_invitation_code
  98. SET used = 1, used_by = ?
  99. WHERE code = ? AND used = 0
  100. `
  101. result, err := tx.ExecContext(ctx, query, usedBy, code)
  102. if err != nil {
  103. return -1, fmt.Errorf("标记邀请码已使用失败: %v", err)
  104. }
  105. return ValidateResultRowsAffected(result, err, 1)
  106. }
  107. // GetInvitationCodeByCode 根据邀请码查询详情
  108. func GetInvitationCodeByCode(ctx context.Context, db *sqlx.DB, code string) (*tables.InvitationCodeDB, error) {
  109. var invitation tables.InvitationCodeDB
  110. query := `
  111. SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
  112. FROM config_invitation_code
  113. WHERE code = ?
  114. `
  115. err := db.GetContext(ctx, &invitation, query, code)
  116. if err != nil {
  117. return nil, fmt.Errorf("查询邀请码失败: %v", err)
  118. }
  119. return &invitation, nil
  120. }
  121. // ListInvitationCodesByTenant 按租户查询邀请码列表
  122. func ListInvitationCodesByTenant(ctx context.Context, db *sqlx.DB, tenantID string, includeUsed bool) ([]tables.InvitationCodeDB, error) {
  123. var invitations []tables.InvitationCodeDB
  124. query := `
  125. SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
  126. FROM config_invitation_code
  127. WHERE tenant_id = ?
  128. `
  129. if !includeUsed {
  130. query += " AND used = 0"
  131. }
  132. query += " ORDER BY created_at DESC"
  133. err := db.SelectContext(ctx, &invitations, query, tenantID)
  134. if err != nil {
  135. return nil, fmt.Errorf("查询邀请码列表失败: %v", err)
  136. }
  137. return invitations, nil
  138. }
  139. // DeleteInvitationCode 删除邀请码(仅限未使用的)
  140. func DeleteInvitationCode(ctx context.Context, tx *sqlx.Tx, code string) (int64, error) {
  141. query := `DELETE FROM config_invitation_code WHERE code = ? AND used = 0`
  142. result, err := tx.ExecContext(ctx, query, code)
  143. if err != nil {
  144. return -1, fmt.Errorf("删除邀请码失败: %v", err)
  145. }
  146. return ValidateResultRowsAffected(result, err, 1)
  147. }
  148. // CountInvitationCodesByTenant 统计租户邀请码数量
  149. func CountInvitationCodesByTenant(ctx context.Context, db *sqlx.DB, tenantID string) (int, error) {
  150. var count int
  151. query := `SELECT COUNT(*) FROM config_invitation_code WHERE tenant_id = ?`
  152. err := db.GetContext(ctx, &count, query, tenantID)
  153. if err != nil {
  154. return 0, fmt.Errorf("统计邀请码数量失败: %v", err)
  155. }
  156. return count, nil
  157. }