Skip to content

Commit 22b10ce

Browse files
authored
Merge pull request #37 from kanmu/feat-conflict
feat: generate INSERT ...ON CONFLICT DO NOTHING
2 parents 454fd9c + 123a833 commit 22b10ce

File tree

3 files changed

+377
-7
lines changed

3 files changed

+377
-7
lines changed

dgw_test.go

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"os"
66
"path/filepath"
7+
"regexp"
78
"testing"
89

910
_ "github.com/lib/pq"
@@ -185,3 +186,297 @@ func TestPgExecuteCustomTemplate(t *testing.T) {
185186
t.Logf("%s", src)
186187
}
187188
}
189+
190+
func TestCreateInsertOnConflictDoNothingSQL(t *testing.T) {
191+
conn, cleanup := testPgSetup(t)
192+
defer cleanup()
193+
194+
structs := testSetupStruct(t, conn)
195+
196+
if len(structs) != 4 {
197+
t.Fatalf("Expected the number of testing structs is 4, got: %d", len(structs))
198+
}
199+
200+
tests := []struct {
201+
tableStruct *Struct
202+
expectSQL string
203+
}{
204+
{
205+
tableStruct: structs[0],
206+
expectSQL: "INSERT INTO t1 (i, str, nullable_str, t_with_tz, t_without_tz, tm) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING RETURNING id",
207+
},
208+
{
209+
tableStruct: structs[1],
210+
expectSQL: "INSERT INTO t2 (i, str, t_with_tz, t_without_tz) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING RETURNING id",
211+
},
212+
{
213+
tableStruct: structs[2],
214+
expectSQL: "INSERT INTO t3 (str, t_with_tz, t_without_tz) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING RETURNING id, i",
215+
},
216+
{
217+
tableStruct: structs[3],
218+
expectSQL: "INSERT INTO t4 (id, i) VALUES ($1, $2) ON CONFLICT DO NOTHING",
219+
},
220+
}
221+
for _, tt := range tests {
222+
t.Run(tt.tableStruct.Table.Name, func(t *testing.T) {
223+
sql := createInsertOnConflictDoNothingSQL(tt.tableStruct)
224+
if sql != tt.expectSQL {
225+
t.Errorf("Expected SQL: %s, got: %s", tt.expectSQL, sql)
226+
}
227+
t.Logf("Table: %s, Generated SQL: %s", tt.tableStruct.Name, sql)
228+
})
229+
}
230+
}
231+
232+
func TestMethodGeneration(t *testing.T) {
233+
conn, cleanup := testPgSetup(t)
234+
defer cleanup()
235+
236+
schema := "public"
237+
tbls, err := PgLoadTableDef(conn, schema)
238+
if err != nil {
239+
t.Fatal(err)
240+
}
241+
242+
if len(tbls) != 4 {
243+
t.Fatalf("Expected the number of testing PgTable is 4, got: %d", len(tbls))
244+
}
245+
246+
tests := []struct {
247+
table *PgTable
248+
expect string
249+
}{
250+
{
251+
table: tbls[0],
252+
expect: `// Create inserts the T1 to the database.
253+
func (r *T1) Create(db Queryer) error {
254+
return r.CreateContext(context.Background(), db)
255+
}
256+
257+
// GetT1ByPk select the T1 from the database.
258+
func GetT1ByPk(db Queryer, pk0 int64) (*T1, error) {
259+
return GetT1ByPkContext(context.Background(), db, pk0)
260+
}
261+
262+
// CreateContext inserts the T1 to the database.
263+
func (r *T1) CreateContext(ctx context.Context, db Queryer) error {
264+
err := db.QueryRowContext(ctx,
265+
` + "`INSERT INTO t1 (i, str, nullable_str, t_with_tz, t_without_tz, tm) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id`" + `,
266+
&r.I, &r.Str, &r.NullableStr, &r.TWithTz, &r.TWithoutTz, &r.Tm).Scan(&r.ID)
267+
if err != nil {
268+
return errors.WithStack(err)
269+
}
270+
return nil
271+
}
272+
273+
// CreateOnConflictDoNothing inserts the T1 to the database.
274+
// If a conflict occurs (e.g., unique constraint violation), the insert is skipped without error.
275+
// Returns true if the row was inserted, false if it was skipped due to conflict.
276+
func (r *T1) CreateOnConflictDoNothing(ctx context.Context, db Queryer) (bool, error) {
277+
err := db.QueryRowContext(ctx,
278+
` + "`INSERT INTO t1 (i, str, nullable_str, t_with_tz, t_without_tz, tm) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING RETURNING id`" + `,
279+
&r.I, &r.Str, &r.NullableStr, &r.TWithTz, &r.TWithoutTz, &r.Tm).Scan(&r.ID)
280+
if err != nil {
281+
if err == sql.ErrNoRows {
282+
return false, nil
283+
}
284+
return false, errors.WithStack(err)
285+
}
286+
// Row was successfully inserted
287+
return true, nil
288+
}
289+
290+
// GetT1ByPkContext select the T1 from the database.
291+
func GetT1ByPkContext(ctx context.Context, db Queryer, pk0 int64) (*T1, error) {
292+
var r T1
293+
err := db.QueryRowContext(ctx,
294+
` + "`SELECT id, i, str, nullable_str, t_with_tz, t_without_tz, tm FROM t1 WHERE id = $1`" + `,
295+
pk0).Scan(&r.ID, &r.I, &r.Str, &r.NullableStr, &r.TWithTz, &r.TWithoutTz, &r.Tm)
296+
if err != nil {
297+
return nil, errors.WithStack(err)
298+
}
299+
return &r, nil
300+
}
301+
302+
`,
303+
},
304+
{
305+
table: tbls[1],
306+
expect: `// Create inserts the T2 to the database.
307+
func (r *T2) Create(db Queryer) error {
308+
return r.CreateContext(context.Background(), db)
309+
}
310+
311+
// GetT2ByPk select the T2 from the database.
312+
func GetT2ByPk(db Queryer, pk0 int64) (*T2, error) {
313+
return GetT2ByPkContext(context.Background(), db, pk0)
314+
}
315+
316+
// CreateContext inserts the T2 to the database.
317+
func (r *T2) CreateContext(ctx context.Context, db Queryer) error {
318+
err := db.QueryRowContext(ctx,
319+
` + "`INSERT INTO t2 (i, str, t_with_tz, t_without_tz) VALUES ($1, $2, $3, $4) RETURNING id`" + `,
320+
&r.I, &r.Str, &r.TWithTz, &r.TWithoutTz).Scan(&r.ID)
321+
if err != nil {
322+
return errors.WithStack(err)
323+
}
324+
return nil
325+
}
326+
327+
// CreateOnConflictDoNothing inserts the T2 to the database.
328+
// If a conflict occurs (e.g., unique constraint violation), the insert is skipped without error.
329+
// Returns true if the row was inserted, false if it was skipped due to conflict.
330+
func (r *T2) CreateOnConflictDoNothing(ctx context.Context, db Queryer) (bool, error) {
331+
err := db.QueryRowContext(ctx,
332+
` + "`INSERT INTO t2 (i, str, t_with_tz, t_without_tz) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING RETURNING id`" + `,
333+
&r.I, &r.Str, &r.TWithTz, &r.TWithoutTz).Scan(&r.ID)
334+
if err != nil {
335+
if err == sql.ErrNoRows {
336+
return false, nil
337+
}
338+
return false, errors.WithStack(err)
339+
}
340+
// Row was successfully inserted
341+
return true, nil
342+
}
343+
344+
// GetT2ByPkContext select the T2 from the database.
345+
func GetT2ByPkContext(ctx context.Context, db Queryer, pk0 int64) (*T2, error) {
346+
var r T2
347+
err := db.QueryRowContext(ctx,
348+
` + "`SELECT id, i, str, t_with_tz, t_without_tz FROM t2 WHERE id = $1`" + `,
349+
pk0).Scan(&r.ID, &r.I, &r.Str, &r.TWithTz, &r.TWithoutTz)
350+
if err != nil {
351+
return nil, errors.WithStack(err)
352+
}
353+
return &r, nil
354+
}
355+
`,
356+
},
357+
{
358+
table: tbls[2],
359+
expect: `// Create inserts the T3 to the database.
360+
func (r *T3) Create(db Queryer) error {
361+
return r.CreateContext(context.Background(), db)
362+
}
363+
364+
// GetT3ByPk select the T3 from the database.
365+
func GetT3ByPk(db Queryer, pk0 int64, pk1 int) (*T3, error) {
366+
return GetT3ByPkContext(context.Background(), db, pk0, pk1)
367+
}
368+
369+
// CreateContext inserts the T3 to the database.
370+
func (r *T3) CreateContext(ctx context.Context, db Queryer) error {
371+
err := db.QueryRowContext(ctx,
372+
` + "`INSERT INTO t3 (str, t_with_tz, t_without_tz) VALUES ($1, $2, $3) RETURNING id, i`" + `,
373+
&r.Str, &r.TWithTz, &r.TWithoutTz).Scan(&r.ID, &r.I)
374+
if err != nil {
375+
return errors.WithStack(err)
376+
}
377+
return nil
378+
}
379+
380+
// CreateOnConflictDoNothing inserts the T3 to the database.
381+
// If a conflict occurs (e.g., unique constraint violation), the insert is skipped without error.
382+
// Returns true if the row was inserted, false if it was skipped due to conflict.
383+
func (r *T3) CreateOnConflictDoNothing(ctx context.Context, db Queryer) (bool, error) {
384+
err := db.QueryRowContext(ctx,
385+
` + "`INSERT INTO t3 (str, t_with_tz, t_without_tz) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING RETURNING id, i`" + `,
386+
&r.Str, &r.TWithTz, &r.TWithoutTz).Scan(&r.ID, &r.I)
387+
if err != nil {
388+
if err == sql.ErrNoRows {
389+
return false, nil
390+
}
391+
return false, errors.WithStack(err)
392+
}
393+
// Row was successfully inserted
394+
return true, nil
395+
}
396+
397+
// GetT3ByPkContext select the T3 from the database.
398+
func GetT3ByPkContext(ctx context.Context, db Queryer, pk0 int64, pk1 int) (*T3, error) {
399+
var r T3
400+
err := db.QueryRowContext(ctx,
401+
` + "`SELECT id, i, str, t_with_tz, t_without_tz FROM t3 WHERE id = $1 AND i = $2`" + `,
402+
pk0, pk1).Scan(&r.ID, &r.I, &r.Str, &r.TWithTz, &r.TWithoutTz)
403+
if err != nil {
404+
return nil, errors.WithStack(err)
405+
}
406+
return &r, nil
407+
}
408+
`,
409+
},
410+
{
411+
table: tbls[3],
412+
expect: `// Create inserts the T4 to the database.
413+
func (r *T4) Create(db Queryer) error {
414+
return r.CreateContext(context.Background(), db)
415+
}
416+
417+
// GetT4ByPk select the T4 from the database.
418+
func GetT4ByPk(db Queryer, pk0 int, pk1 int) (*T4, error) {
419+
return GetT4ByPkContext(context.Background(), db, pk0, pk1)
420+
}
421+
422+
// CreateContext inserts the T4 to the database.
423+
func (r *T4) CreateContext(ctx context.Context, db Queryer) error {
424+
_, err := db.ExecContext(ctx,
425+
` + "`INSERT INTO t4 (id, i) VALUES ($1, $2)`" + `,
426+
&r.ID, &r.I)
427+
if err != nil {
428+
return errors.WithStack(err)
429+
}
430+
return nil
431+
}
432+
433+
// CreateOnConflictDoNothing inserts the T4 to the database.
434+
// If a conflict occurs (e.g., unique constraint violation), the insert is skipped without error.
435+
// Returns true if the row was inserted, false if it was skipped due to conflict.
436+
func (r *T4) CreateOnConflictDoNothing(ctx context.Context, db Queryer) (bool, error) {
437+
result, err := db.ExecContext(ctx,
438+
` + "`INSERT INTO t4 (id, i) VALUES ($1, $2) ON CONFLICT DO NOTHING`" + `,
439+
&r.ID, &r.I)
440+
if err != nil {
441+
return false, errors.WithStack(err)
442+
}
443+
rowsAffected, err := result.RowsAffected()
444+
if err != nil {
445+
return false, errors.WithStack(err)
446+
}
447+
return rowsAffected > 0, nil
448+
}
449+
450+
// GetT4ByPkContext select the T4 from the database.
451+
func GetT4ByPkContext(ctx context.Context, db Queryer, pk0 int, pk1 int) (*T4, error) {
452+
var r T4
453+
err := db.QueryRowContext(ctx,
454+
` + "`SELECT id, i FROM t4 WHERE id = $1 AND i = $2`" + `,
455+
pk0, pk1).Scan(&r.ID, &r.I)
456+
if err != nil {
457+
return nil, errors.WithStack(err)
458+
}
459+
return &r, nil
460+
}`,
461+
},
462+
}
463+
for _, tt := range tests {
464+
t.Run(tt.table.Name, func(t *testing.T) {
465+
st, err := PgTableToStruct(tt.table, &defaultTypeMapCfg, autoGenKeyCfg)
466+
if err != nil {
467+
t.Fatal(err)
468+
}
469+
src, err := PgExecuteDefaultMethodTmpl(&StructTmpl{Struct: st})
470+
if err != nil {
471+
t.Fatal(err)
472+
}
473+
474+
re1 := regexp.MustCompile(`\s`)
475+
476+
srcStr := string(src)
477+
if re1.ReplaceAllString(srcStr, "") != re1.ReplaceAllString(tt.expect, "") {
478+
t.Errorf("Expected generated code: %s, got: %s", tt.expect, srcStr)
479+
}
480+
})
481+
}
482+
}

funcmap.go

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
)
77

88
var tmplFuncMap = template.FuncMap{
9-
"createInsertSQL": createInsertSQL,
10-
"createInsertParams": createInsertParams,
11-
"createInsertScan": createInsertScan,
12-
"createSelectByPkSQL": createSelectByPkSQL,
13-
"createSelectByPkFuncParams": createSelectByPkFuncParams,
14-
"createSelectByPkSQLParams": createSelectByPkSQLParams,
15-
"createSelectByPkScan": createSelectByPkScan,
9+
"createInsertSQL": createInsertSQL,
10+
"createInsertOnConflictDoNothingSQL": createInsertOnConflictDoNothingSQL,
11+
"createInsertParams": createInsertParams,
12+
"createInsertScan": createInsertScan,
13+
"createSelectByPkSQL": createSelectByPkSQL,
14+
"createSelectByPkFuncParams": createSelectByPkFuncParams,
15+
"createSelectByPkSQLParams": createSelectByPkSQLParams,
16+
"createSelectByPkScan": createSelectByPkScan,
1617
}
1718

1819
func createSelectByPkSQL(st *Struct) string {
@@ -153,3 +154,46 @@ func createInsertSQL(st *Struct) string {
153154
}
154155
return sql
155156
}
157+
158+
func createInsertOnConflictDoNothingSQL(st *Struct) string {
159+
var sql string
160+
sql = "INSERT INTO " + st.Table.Name + " ("
161+
162+
if len(st.Table.Columns) == 1 && st.Table.Columns[0].IsPrimaryKey && st.Table.AutoGenPk {
163+
sql = sql + st.Table.Columns[0].Name + ") VALUES (DEFAULT)"
164+
} else {
165+
var colNames []string
166+
for _, c := range st.Table.Columns {
167+
if c.IsPrimaryKey && st.Table.AutoGenPk {
168+
continue
169+
} else {
170+
colNames = append(colNames, c.Name)
171+
}
172+
}
173+
sql = sql + flatten(colNames, ", ") + ") VALUES ("
174+
175+
var fieldNames []string
176+
for _, f := range st.Fields {
177+
if f.Column.IsPrimaryKey && st.Table.AutoGenPk {
178+
continue
179+
} else {
180+
fieldNames = append(fieldNames, f.Name)
181+
}
182+
}
183+
sql = sql + placeholders(fieldNames) + ")"
184+
}
185+
186+
sql = sql + " ON CONFLICT DO NOTHING"
187+
188+
if st.Table.AutoGenPk {
189+
sql = sql + " RETURNING "
190+
for i, c := range st.Table.PrimaryKeys {
191+
if i == 0 {
192+
sql = sql + c.Name
193+
} else {
194+
sql = sql + ", " + c.Name
195+
}
196+
}
197+
}
198+
return sql
199+
}

0 commit comments

Comments
 (0)