package dao

import (
	"context"
	crand "crypto/rand"
	"errors"
	"fmt"
	"math/rand"
	"time"

	"git.x2erp.com/qdy/go-svc-configure/internal/tables"
	"github.com/jmoiron/sqlx"
)

const (
	// invitationCodeLength is the number of characters in a generated code.
	invitationCodeLength = 16
	// invitationCodeChars is the alphabet generated codes are drawn from.
	invitationCodeChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// defaultExpiresDays is the validity period applied when the caller
	// passes a non-positive expiresDays to CreateInvitationCode.
	defaultExpiresDays = 7
)

// GenerateInvitationCode returns a random invitation code of
// invitationCodeLength characters drawn from invitationCodeChars.
//
// The code is random, not guaranteed unique; CreateInvitationCode performs
// the uniqueness check against the database. Randomness comes from
// crypto/rand. If the system source reports an error, the function falls back
// to a locally seeded math/rand generator so that the signature (no error
// return) stays unchanged.
func GenerateInvitationCode() string {
	code := make([]byte, invitationCodeLength)
	if err := fillInvitationCodeSecure(code); err != nil {
		fillInvitationCodeFallback(code)
	}
	return string(code)
}

// fillInvitationCodeSecure fills dst with characters from invitationCodeChars
// using crypto/rand. Random bytes that would introduce modulo bias are
// rejected and redrawn.
func fillInvitationCodeSecure(dst []byte) error {
	const alphabetSize = len(invitationCodeChars)
	// Largest multiple of alphabetSize that fits in a byte; values at or
	// above it are discarded so every character is equally likely.
	const limit = 256 - 256%alphabetSize

	buf := make([]byte, 2*len(dst))
	filled := 0
	for filled < len(dst) {
		if _, err := crand.Read(buf); err != nil {
			return err
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue
			}
			dst[filled] = invitationCodeChars[int(v)%alphabetSize]
			filled++
			if filled == len(dst) {
				break
			}
		}
	}
	return nil
}

// fillInvitationCodeFallback fills dst using a private math/rand generator
// seeded from the current time. It does not touch the global math/rand state.
func fillInvitationCodeFallback(dst []byte) {
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	for i := range dst {
		dst[i] = invitationCodeChars[r.Intn(len(invitationCodeChars))]
	}
}

// CreateInvitationCode generates an invitation code that is not yet present
// in config_invitation_code and inserts it within tx.
//
// A non-positive expiresDays is replaced by defaultExpiresDays. Up to ten
// candidate codes are tried; if all of them already exist an error is
// returned. On success the generated code is returned.
func CreateInvitationCode(ctx context.Context, tx *sqlx.Tx, tenantID, roleID, creator string, expiresDays int) (string, error) {
	if expiresDays <= 0 {
		expiresDays = defaultExpiresDays
	}

	const maxAttempts = 10
	const checkQuery = `SELECT COUNT(*) FROM config_invitation_code WHERE code = ?`

	// NOTE(review): this check-then-insert is only race-free if the table
	// also enforces uniqueness on code; confirm against the schema.
	var code string
	unique := false
	for i := 0; i < maxAttempts; i++ {
		code = GenerateInvitationCode()

		var count int
		if err := tx.GetContext(ctx, &count, checkQuery, code); err != nil {
			return "", fmt.Errorf("检查邀请码存在性失败: %w", err)
		}
		if count == 0 {
			unique = true
			break
		}
	}
	if !unique {
		return "", errors.New("生成唯一邀请码失败,请重试")
	}

	now := time.Now()
	expiresAt := now.Add(time.Duration(expiresDays) * 24 * time.Hour)

	// NOTE(review): the id has one-second resolution, so two invitations for
	// the same tenant and role created within the same second get the same id.
	// Left unchanged because other code may depend on the format.
	id := fmt.Sprintf("invitation.%s.%s.%d", tenantID, roleID, now.Unix())

	query := `
		INSERT INTO config_invitation_code
		(id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at)
		VALUES (?, ?, ?, ?, ?, 0, '', ?, CURRENT_TIMESTAMP)
	`

	result, err := tx.ExecContext(ctx, query, id, code, tenantID, roleID, expiresAt, creator)
	if err != nil {
		return "", fmt.Errorf("创建邀请码失败: %w", err)
	}

	rowsAffected, err := ValidateResultRowsAffected(result, err, 1)
	if err != nil {
		return "", err
	}
	if rowsAffected != 1 {
		return "", fmt.Errorf("创建邀请码失败,影响行数: %d", rowsAffected)
	}

	return code, nil
}

// ValidateInvitationCode loads the invitation identified by code and checks
// that it can still be redeemed.
//
// It returns an error when the lookup fails (the underlying error is wrapped,
// so callers can use errors.Is), when the used column is non-zero, or when
// time.Now() is after expires_at. It only reads; it does not mark the code as
// used.
func ValidateInvitationCode(ctx context.Context, db *sqlx.DB, code string) (*tables.InvitationCodeDB, error) {
	var invitation tables.InvitationCodeDB

	query := `
		SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
		FROM config_invitation_code
		WHERE code = ?
	`

	if err := db.GetContext(ctx, &invitation, query, code); err != nil {
		return nil, fmt.Errorf("邀请码不存在: %w", err)
	}

	// Already redeemed.
	if invitation.Used != 0 {
		return nil, errors.New("邀请码已使用")
	}

	// Past its expiry time.
	if time.Now().After(invitation.ExpiresAt) {
		return nil, errors.New("邀请码已过期")
	}

	return &invitation, nil
}

// MarkInvitationCodeUsed sets used = 1 and used_by = usedBy for the given
// code within tx. The UPDATE only matches rows that still have used = 0.
//
// It returns -1 and a wrapped error when the statement fails; otherwise it
// returns whatever ValidateResultRowsAffected reports for an expected row
// count of 1.
func MarkInvitationCodeUsed(ctx context.Context, tx *sqlx.Tx, code, usedBy string) (int64, error) {
	query := `
		UPDATE config_invitation_code
		SET used = 1, used_by = ?
		WHERE code = ? AND used = 0
	`

	result, err := tx.ExecContext(ctx, query, usedBy, code)
	if err != nil {
		return -1, fmt.Errorf("标记邀请码已使用失败: %w", err)
	}

	return ValidateResultRowsAffected(result, err, 1)
}

// GetInvitationCodeByCode returns the invitation row whose code column equals
// code, without checking its used or expiry state. Any lookup error is
// wrapped and returned.
func GetInvitationCodeByCode(ctx context.Context, db *sqlx.DB, code string) (*tables.InvitationCodeDB, error) {
	var invitation tables.InvitationCodeDB

	query := `
		SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
		FROM config_invitation_code
		WHERE code = ?
	`

	if err := db.GetContext(ctx, &invitation, query, code); err != nil {
		return nil, fmt.Errorf("查询邀请码失败: %w", err)
	}

	return &invitation, nil
}

// ListInvitationCodesByTenant returns the invitations belonging to tenantID,
// ordered by created_at descending. When includeUsed is false, only rows with
// used = 0 are returned.
func ListInvitationCodesByTenant(ctx context.Context, db *sqlx.DB, tenantID string, includeUsed bool) ([]tables.InvitationCodeDB, error) {
	var invitations []tables.InvitationCodeDB

	query := `
		SELECT id, code, tenant_id, role_id, expires_at, used, used_by, creator, created_at
		FROM config_invitation_code
		WHERE tenant_id = ?
	`

	if !includeUsed {
		query += " AND used = 0"
	}
	query += " ORDER BY created_at DESC"

	if err := db.SelectContext(ctx, &invitations, query, tenantID); err != nil {
		return nil, fmt.Errorf("查询邀请码列表失败: %w", err)
	}

	return invitations, nil
}

// DeleteInvitationCode deletes the invitation with the given code within tx,
// but only if it has not been used (used = 0).
//
// It returns -1 and a wrapped error when the statement fails; otherwise it
// returns whatever ValidateResultRowsAffected reports for an expected row
// count of 1.
func DeleteInvitationCode(ctx context.Context, tx *sqlx.Tx, code string) (int64, error) {
	query := `DELETE FROM config_invitation_code WHERE code = ? AND used = 0`

	result, err := tx.ExecContext(ctx, query, code)
	if err != nil {
		return -1, fmt.Errorf("删除邀请码失败: %w", err)
	}

	return ValidateResultRowsAffected(result, err, 1)
}

// CountInvitationCodesByTenant returns the number of invitation rows (used or
// not) whose tenant_id equals tenantID.
func CountInvitationCodesByTenant(ctx context.Context, db *sqlx.DB, tenantID string) (int, error) {
	var count int

	query := `SELECT COUNT(*) FROM config_invitation_code WHERE tenant_id = ?`

	if err := db.GetContext(ctx, &count, query, tenantID); err != nil {
		return 0, fmt.Errorf("统计邀请码数量失败: %w", err)
	}

	return count, nil
}