Нет описания
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

table_definition.go 9.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. package sqldef
  import (
	"fmt"
	"sort"
	"strings"
	"sync"

	"git.x2erp.com/qdy/go-db/sqldef/generators"
)
  // SQL column type names used by the TableBuilder column helpers. They are
// untyped string constants, so they compare and format as plain strings.
const (
	// Character and text types.
	TypeVarchar = "VARCHAR"
	TypeChar    = "CHAR"
	TypeText    = "TEXT"

	// Integer types.
	TypeInt     = "INT"
	TypeBigInt  = "BIGINT"
	TypeTinyInt = "TINYINT"

	// Other numeric types.
	TypeDecimal = "DECIMAL"
	TypeFloat   = "FLOAT"
	TypeDouble  = "DOUBLE"

	// Boolean type.
	TypeBool = "BOOL"

	// Date and time types.
	TypeDateTime  = "DATETIME"
	TypeTimestamp = "TIMESTAMP"
	TypeDate      = "DATE"
	TypeTime      = "TIME"

	// Binary and structured types.
	TypeBlob = "BLOB"
	TypeJson = "JSON"
)
  // Column is the definition of a single column as accumulated by a
// ColumnBuilder: its name, full SQL type string (including any size, e.g.
// "VARCHAR(64)" or "DECIMAL(10,2)"), comment, option keywords, and the
// default value recorded by Default.
type Column struct {
	name, sqlType, comment string
	options                []string // option keywords such as "NOT NULL", in the order added
	defaultValue           string   // value recorded by ColumnBuilder.Default ("" if never set)
}
  35. // ColumnBuilder 列构建器
  36. type ColumnBuilder struct {
  37. table *TableBuilder
  38. column Column
  39. }
  40. // TableBuilder 表构建器
  41. type TableBuilder struct {
  42. name string
  43. comment string
  44. columns []Column
  45. indexes []string
  46. }
  47. // Registry 注册表(懒加载)
  48. type Registry struct {
  49. tables map[string]generators.TableDDL
  50. mu sync.RWMutex
  51. once sync.Once
  52. regFns []func(*Registry)
  53. }
  54. var globalRegistry = &Registry{
  55. tables: make(map[string]generators.TableDDL),
  56. }
  57. // AddRegistration 添加注册函数
  58. func AddRegistration(fn func(*Registry)) {
  59. globalRegistry.regFns = append(globalRegistry.regFns, fn)
  60. }
  61. func (r *Registry) lazyInit() {
  62. r.mu.Lock()
  63. defer r.mu.Unlock()
  64. for _, fn := range r.regFns {
  65. fn(r)
  66. }
  67. }
  68. func (r *Registry) ensureInit() {
  69. r.once.Do(r.lazyInit)
  70. }
  71. // RegisterTable 注册表
  72. func (r *Registry) RegisterTable(table generators.TableDDL) {
  73. if table.Name != "" {
  74. r.tables[table.Name] = table
  75. }
  76. }
  77. // GetAll 获取所有表定义
  78. func GetAll() []generators.TableDDL {
  79. globalRegistry.ensureInit()
  80. globalRegistry.mu.RLock()
  81. defer globalRegistry.mu.RUnlock()
  82. result := make([]generators.TableDDL, 0, len(globalRegistry.tables))
  83. for _, table := range globalRegistry.tables {
  84. result = append(result, table)
  85. }
  86. return result
  87. }
  88. // Get 获取表定义
  89. func Get(tableName string) (generators.TableDDL, bool) {
  90. globalRegistry.ensureInit()
  91. globalRegistry.mu.RLock()
  92. defer globalRegistry.mu.RUnlock()
  93. table, exists := globalRegistry.tables[tableName]
  94. return table, exists
  95. }
  96. // ================== TableBuilder 表构建器 ==================
  97. // NewTable 创建新表
  98. func NewTable(name string, comment ...string) *TableBuilder {
  99. tb := &TableBuilder{name: name}
  100. if len(comment) > 0 {
  101. tb.comment = comment[0]
  102. }
  103. return tb
  104. }
  105. // ================== 列定义方法 ==================
  106. // 默认64位长度
  107. func (t *TableBuilder) ID(name string, length ...int) *ColumnBuilder {
  108. size := 64
  109. if len(length) > 0 {
  110. size = length[0]
  111. }
  112. return t.column(name, TypeVarchar, size).PrimaryKey()
  113. }
  114. func (t *TableBuilder) String(name string, length int) *ColumnBuilder {
  115. return t.column(name, TypeVarchar, length)
  116. }
  117. func (t *TableBuilder) Char(name string, length int) *ColumnBuilder {
  118. return t.column(name, TypeChar, length)
  119. }
  120. func (t *TableBuilder) Text(name string) *ColumnBuilder {
  121. return t.column(name, TypeText, 0)
  122. }
  123. func (t *TableBuilder) Int(name string) *ColumnBuilder {
  124. return t.column(name, TypeInt, 0)
  125. }
  126. func (t *TableBuilder) TinyInt(name string) *ColumnBuilder {
  127. return t.column(name, TypeTinyInt, 0)
  128. }
  129. func (t *TableBuilder) BigInt(name string) *ColumnBuilder {
  130. return t.column(name, TypeBigInt, 0)
  131. }
  132. func (t *TableBuilder) Bool(name string) *ColumnBuilder {
  133. return t.column(name, TypeBool, 0)
  134. }
  135. func (t *TableBuilder) JSON(name string) *ColumnBuilder {
  136. return t.column(name, TypeJson, 0)
  137. }
  138. func (t *TableBuilder) Decimal(name string, precision, scale int) *ColumnBuilder {
  139. col := &ColumnBuilder{
  140. table: t,
  141. column: Column{
  142. name: name,
  143. sqlType: fmt.Sprintf("%s(%d,%d)", TypeDecimal, precision, scale),
  144. },
  145. }
  146. return col
  147. }
  148. func (t *TableBuilder) Float(name string) *ColumnBuilder {
  149. return t.column(name, TypeFloat, 0)
  150. }
  151. func (t *TableBuilder) Double(name string) *ColumnBuilder {
  152. return t.column(name, TypeDouble, 0)
  153. }
  154. func (t *TableBuilder) DateTime(name string) *ColumnBuilder {
  155. return t.column(name, TypeDateTime, 0)
  156. }
  157. func (t *TableBuilder) Timestamp(name string) *ColumnBuilder {
  158. return t.column(name, TypeTimestamp, 0)
  159. }
  160. func (t *TableBuilder) Time(name string) *ColumnBuilder {
  161. return t.column(name, TypeTime, 0)
  162. }
  163. func (t *TableBuilder) Date(name string) *ColumnBuilder {
  164. return t.column(name, TypeDate, 0)
  165. }
  166. func (t *TableBuilder) Blob(name string) *ColumnBuilder {
  167. return t.column(name, TypeBlob, 0)
  168. }
  169. // 私有辅助方法
  170. func (t *TableBuilder) column(name string, dataType string, length int) *ColumnBuilder {
  171. sqlType := dataType
  172. if length > 0 {
  173. sqlType = fmt.Sprintf("%s(%d)", dataType, length)
  174. }
  175. col := &ColumnBuilder{
  176. table: t,
  177. column: Column{
  178. name: name,
  179. sqlType: sqlType,
  180. },
  181. }
  182. return col
  183. }
  184. // ================== 列构建器方法 ==================
  185. func (c *ColumnBuilder) NotNull() *ColumnBuilder {
  186. c.column.options = append(c.column.options, "NOT NULL")
  187. return c
  188. }
  189. func (c *ColumnBuilder) Default(value string) *ColumnBuilder {
  190. c.column.defaultValue = value
  191. return c
  192. }
  193. func (c *ColumnBuilder) PrimaryKey() *ColumnBuilder {
  194. c.column.options = append(c.column.options, "PRIMARY KEY")
  195. return c
  196. }
  197. func (c *ColumnBuilder) AutoIncrement() *ColumnBuilder {
  198. c.column.options = append(c.column.options, "AUTO_INCREMENT")
  199. return c
  200. }
  201. func (c *ColumnBuilder) Unique() *ColumnBuilder {
  202. c.column.options = append(c.column.options, "UNIQUE")
  203. return c
  204. }
  205. func (c *ColumnBuilder) Comment(comment string) *ColumnBuilder {
  206. c.column.comment = comment
  207. return c
  208. }
  209. // End 结束列定义,返回TableBuilder继续定义其他列
  210. func (c *ColumnBuilder) End() *TableBuilder {
  211. c.table.columns = append(c.table.columns, c.column)
  212. return c.table
  213. }
  214. // ================== 索引方法 ==================
  215. func (t *TableBuilder) AddIndex(name string, columns ...string) *TableBuilder {
  216. idx := fmt.Sprintf("INDEX %s (%s)", name, strings.Join(columns, ", "))
  217. t.indexes = append(t.indexes, idx)
  218. return t
  219. }
  220. func (t *TableBuilder) AddUniqueIndex(name string, columns ...string) *TableBuilder {
  221. idx := fmt.Sprintf("UNIQUE INDEX %s (%s)", name, strings.Join(columns, ", "))
  222. t.indexes = append(t.indexes, idx)
  223. return t
  224. }
  225. // Build 构建表定义,包含完整的Schema信息
  226. func (t *TableBuilder) Build() generators.TableDDL {
  227. // 构建列Schema
  228. columns := make([]generators.ColumnSchema, 0, len(t.columns))
  229. for _, col := range t.columns {
  230. // 解析列类型
  231. colType, length, precision, scale := parseColumnType(col.sqlType)
  232. // 提取默认值
  233. var defaultValue string
  234. var options []string
  235. for _, opt := range col.options {
  236. if strings.HasPrefix(opt, "DEFAULT ") {
  237. defaultValue = strings.TrimPrefix(opt, "DEFAULT ")
  238. } else {
  239. options = append(options, opt)
  240. }
  241. }
  242. columns = append(columns, generators.ColumnSchema{
  243. Name: col.name,
  244. Type: colType,
  245. Length: length,
  246. Precision: precision,
  247. Scale: scale,
  248. Comment: col.comment,
  249. Options: options,
  250. Default: defaultValue,
  251. })
  252. }
  253. // 构建索引Schema
  254. indexes := make([]generators.IndexSchema, 0, len(t.indexes))
  255. for _, idx := range t.indexes {
  256. // 解析索引字符串,例如: "INDEX idx_name (col1, col2)" 或 "UNIQUE INDEX idx_name (col1, col2)"
  257. indexName, isUnique, columns := parseIndex(idx)
  258. if indexName != "" && len(columns) > 0 {
  259. indexes = append(indexes, generators.IndexSchema{
  260. Name: indexName,
  261. Columns: columns,
  262. Unique: isUnique,
  263. })
  264. }
  265. }
  266. return generators.TableDDL{
  267. Name: t.name,
  268. Comment: t.comment,
  269. Schema: &generators.TableSchema{
  270. Name: t.name,
  271. Comment: t.comment,
  272. Columns: columns,
  273. Indexes: indexes,
  274. },
  275. }
  276. }
  // parseColumnType splits a SQL column type string such as "VARCHAR(255)" or
// "DECIMAL(10,2)" into its upper-cased base type and numeric parameters.
//
//   - DECIMAL(p,s) sets precision and scale; DECIMAL(p) sets precision only
//     (scale stays 0).
//   - Any other parameterised type, e.g. VARCHAR(n), reports its single
//     parameter as length.
//   - A type without parentheses returns zero for all numeric results.
//
// Parsing is best effort. Surrounding whitespace is ignored, a missing ")" is
// tolerated, and a parameter that is not an integer is reported as 0.
func parseColumnType(sqlType string) (colType string, length, precision, scale int) {
	sqlType = strings.ToUpper(strings.TrimSpace(sqlType))

	openParen := strings.Index(sqlType, "(")
	if openParen < 0 {
		return sqlType, 0, 0, 0
	}
	colType = strings.TrimSpace(sqlType[:openParen])

	// Look for ")" only after "(" so malformed input such as "VARCHAR(10" or
	// "X)(" cannot produce an invalid slice range and panic.
	params := sqlType[openParen+1:]
	if closeParen := strings.Index(params, ")"); closeParen >= 0 {
		params = params[:closeParen]
	}
	params = strings.TrimSpace(params)

	// Sscanf errors are deliberately ignored: an unparsable parameter simply
	// leaves the corresponding result at 0.
	switch colType {
	case "DECIMAL":
		parts := strings.Split(params, ",")
		switch len(parts) {
		case 1:
			_, _ = fmt.Sscanf(parts[0], "%d", &precision)
		case 2:
			_, _ = fmt.Sscanf(parts[0], "%d", &precision)
			_, _ = fmt.Sscanf(parts[1], "%d", &scale)
		}
	default:
		_, _ = fmt.Sscanf(params, "%d", &length)
	}
	return colType, length, precision, scale
}
  // parseIndex parses an index definition in the textual form produced by
// AddIndex and AddUniqueIndex, e.g. "INDEX idx_name (col1, col2)" or
// "UNIQUE INDEX idx_name (col1, col2)".
//
// It returns the index name, whether the index is unique, and its columns.
// Input that does not start with "INDEX " or "UNIQUE INDEX ", or that has no
// column part after the name, yields ("", false, nil). Exactly one pair of
// enclosing parentheses is removed, so expression columns such as
// "LOWER(name)" survive intact. Column entries are whitespace-trimmed and
// empty entries are dropped, so "INDEX idx ()" returns no columns.
func parseIndex(indexStr string) (name string, isUnique bool, columns []string) {
	rest := strings.TrimSpace(indexStr)
	switch {
	case strings.HasPrefix(rest, "UNIQUE INDEX "):
		isUnique = true
		rest = strings.TrimPrefix(rest, "UNIQUE INDEX ")
	case strings.HasPrefix(rest, "INDEX "):
		rest = strings.TrimPrefix(rest, "INDEX ")
	default:
		return "", false, nil
	}

	// The name runs up to the first space; the remainder is the column list.
	rest = strings.TrimSpace(rest)
	sep := strings.IndexByte(rest, ' ')
	if sep < 0 {
		return "", false, nil
	}
	name = rest[:sep]

	colsPart := strings.TrimSpace(rest[sep+1:])
	colsPart = strings.TrimPrefix(colsPart, "(")
	colsPart = strings.TrimSuffix(colsPart, ")")
	for _, c := range strings.Split(colsPart, ",") {
		if c = strings.TrimSpace(c); c != "" {
			columns = append(columns, c)
		}
	}
	return name, isUnique, columns
}
  329. // Register 快捷方法:直接注册表
  330. func (t *TableBuilder) Register() {
  331. table := t.Build()
  332. AddRegistration(func(r *Registry) {
  333. r.RegisterTable(table)
  334. })
  335. }