sqlfunc

package module
v0.0.0-...-ecbd39a Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Mar 28, 2024 License: Apache-2.0 Imports: 5 Imported by: 0

README

sqlfunc - Stronger typing for database/sql prepared statements

Go Reference CI Coverage Go Report Card

Status

Production ready.

Check the code coverage of the test suite.

Known issues
  • Using the sqlfunc wrappers has a speed and memory cost (see `go test -bench B -benchmem github.com/dolmen-go/sqlfunc`). Run your own benchmarks. There are plans to reduce this cost with a code generator that avoids runtime reflection, but no release date has been set for this complex feature.

License

Copyright 2023 Olivier Mengué

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Documentation

Overview

Package sqlfunc provides utilities to wrap SQL prepared statements with strongly typed Go functions.

You just have to define the function signature you need:

var whoami func(context.Context) (string, error)

and the SQL statement that this function wraps:

close, err := sqlfunc.QueryRow(ctx, db, `SELECT USER()`, &whoami)  // MySQL example
defer close()

You can now use the function:

user, err := whoami(ctx)
fmt.Println("Connected as", user)

Index

Examples

Constants

This section is empty.

Variables

This section is empty.

Functions

func Exec

func Exec(ctx context.Context, db PrepareConn, query string, fnPtr interface{}) (close func() error, err error)

Exec prepares an SQL statement and creates a function wrapping sql.Stmt.ExecContext.

fnPtr is a pointer to a func variable. The function signature tells how it will be called.

The first argument is a context.Context. If a *sql.Tx is given as the second argument, the statement will be localized to the transaction (using sql.Tx.StmtContext). The following arguments will be given as arguments to sql.Stmt.ExecContext.

The function will return an sql.Result and an error.

The returned func 'close' must be called once the statement is not needed anymore.

Example:

var f func(ctx context.Context, arg1 int64, arg2 string, arg3 sql.NullInt64, arg4 time.Time) (sql.Result, error)
close1, err := sqlfunc.Exec(ctx, db, "SELECT ?, ?, ?, ?", &f)
// if err != nil ...
defer close1()
res, err := f(ctx, 1, "a", sql.NullInt64{Valid: false}, time.Now())

Example with transaction:

var fTx func(ctx context.Context, tx *sql.Tx, arg1 int64) (sql.Result, error)
close2, err := sqlfunc.Exec(ctx, db, "SELECT ?", &fTx)
// if err != nil ...
defer close2()

tx, err := db.BeginTx(ctx, nil)
// if err != nil ...
res, err := fTx(ctx, tx, 123)
// if err != nil ...
err = tx.Commit()
// if err != nil ...
Example
check := func(msg string, err error) {
	if err != nil {
		panic(fmt.Errorf("%s: %v", msg, err))
	}
}

ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
check("Open", err)
defer db.Close()

// POI = Point of Interest
_, err = db.ExecContext(ctx, `CREATE TABLE poi (lat DECIMAL, lon DECIMAL, name VARCHAR(255))`)
check("Create table", err)

// newPOI is the function that will call the INSERT statement
var newPOI func(ctx context.Context, lat float32, lon float32, name string) (sql.Result, error)
closeStmt, err := sqlfunc.Exec(
	ctx, db,
	`INSERT INTO poi (lat, lon, name) VALUES (?, ?, ?)`,
	&newPOI,
)
check("Prepare newPOI", err)
defer closeStmt()

// To call the prepared statement we use the strongly typed function
_, err = newPOI(ctx, 48.8016, 2.1204, "Château de Versailles")
check("newPOI", err)

var name string
err = db.QueryRow(`` +
	`SELECT name` +
	` FROM poi` +
	` WHERE lat BETWEEN 48.8015 AND 48.8017` +
	` AND lon BETWEEN 2.1203 AND 2.1205`,
).Scan(&name)
check("Query", err)

fmt.Println(name)

var getPOICoord func(ctx context.Context, name string) (lat float64, lon float64, err error)
closeStmt, err = sqlfunc.QueryRow(
	ctx, db, ``+
		`SELECT lat, lon`+
		` FROM poi`+
		` WHERE name = ?`,
	&getPOICoord,
)
check("Prepare getPOICoord", err)
defer closeStmt()

_, _, err = getPOICoord(ctx, "Trifoully-les-Oies")
if err != sql.ErrNoRows {
	log.Printf("getPOICoord should fail with sql.ErrNoRows")
	return
}

lat, lon, err := getPOICoord(ctx, "Château de Versailles")
if err != nil {
	log.Printf("getPOICoord should succeed but %q", err)
	return
}
fmt.Printf("%.4f, %.4f\n", lat, lon)
Output:

Château de Versailles
48.8016, 2.1204
Example (WithTx)

ExampleExec_withTx shows support for transactions.

check := func(msg string, err error) {
	if err != nil {
		panic(fmt.Errorf("%s: %v", msg, err))
	}
}

ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
check("Open", err)
defer db.Close()

conn, err := db.Conn(ctx)
check("Conn", err)

// POI = Point of Interest
_, err = conn.ExecContext(ctx, `CREATE TABLE poi (lat DECIMAL, lon DECIMAL, name VARCHAR(255))`)
check("Create table", err)

var countPOI func(ctx context.Context) (int64, error)
closeCountPOI, err := sqlfunc.QueryRow(
	ctx, conn,
	`SELECT COUNT(*) FROM poi`,
	&countPOI,
)
check("Prepare countPOI", err)
defer closeCountPOI()

var queryNames func(ctx context.Context) (*sql.Rows, error)
closeQueryNames, err := sqlfunc.Query(
	ctx, conn,
	`SELECT name FROM poi ORDER BY name`,
	&queryNames,
)
check("Prepare queryNames", err)
defer closeQueryNames()

nbPOI, err := countPOI(ctx)
check("countPOI", err)

fmt.Println("countPOI before insert:", nbPOI)

var insertPOI func(ctx context.Context, tx *sql.Tx, lat, lon float64, name string) (sql.Result, error)
closeInsertPOI, err := sqlfunc.Exec(
	ctx, db,
	`INSERT INTO poi (lat, lon, name) VALUES (?, ?, ?)`,
	&insertPOI,
)
check("Prepare insertPOI", err)
defer closeInsertPOI()

tx, err := conn.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
check("BeginTx", err)
defer tx.Rollback()

res, err := insertPOI(ctx, tx, 48.8016, 2.1204, "Château de Versailles")
check("newPOI", err)

nbRows, err := res.RowsAffected()
check("RowsAffected", err)
fmt.Println("Rows inserted:", nbRows)

res, err = insertPOI(ctx, tx, 47.2009, 0.6317, "Villeperdue")
check("newPOI", err)

nbRows, err = res.RowsAffected()
check("RowsAffected", err)
fmt.Println("Rows inserted:", nbRows)

nbPOI, err = countPOI(ctx)
check("countPOI", err)
fmt.Println("countPOI after inserts:", nbPOI)

rows, err := queryNames(ctx)
check("queryNames", err)
var names []string
err = sqlfunc.ForEach(rows, func(name string) {
	names = append(names, name)
})
check("ForEach", err)
fmt.Println("names:", names)

tx.Rollback()

nbPOI, err = countPOI(ctx)
check("countPOI after rollback", err)

fmt.Println("countPOI after rollback:", nbPOI)
Output:

countPOI before insert: 0
Rows inserted: 1
Rows inserted: 1
countPOI after inserts: 2
names: [Château de Versailles Villeperdue]
countPOI after rollback: 0

func ForEach

func ForEach(rows *sql.Rows, callback interface{}) error

ForEach iterates an *sql.Rows, scans the values of the row and calls the given callback function with the values.

The callback receives the scanned column values as arguments. It may return a bool (false stops iterating) or an error (a non-nil error stops iterating and is reported by ForEach).

The rows are closed before ForEach returns.

Example
ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
if err != nil {
	log.Printf("Open: %v", err)
	return
}
defer db.Close()

rows, err := db.QueryContext(ctx, ``+
	`SELECT 1`+
	` UNION ALL`+
	` SELECT 2`)
if err != nil {
	log.Printf("Query: %v", err)
	return
}

err = sqlfunc.ForEach(rows, func(n int) {
	fmt.Println(n)
})
if err != nil {
	log.Printf("ScanRows: %v", err)
	return
}

fmt.Println("Done.")
Output:

1
2
Done.
Example (ReturnBool)
ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
if err != nil {
	log.Printf("Open: %v", err)
	return
}
defer db.Close()

rows, err := db.QueryContext(ctx, ``+
	`SELECT 1`+
	` UNION ALL`+
	` SELECT 2`+
	` UNION ALL`+
	` SELECT 3`)
if err != nil {
	log.Printf("Query: %v", err)
	return
}

err = sqlfunc.ForEach(rows, func(n int) bool {
	fmt.Println(n)
	return n < 2 // Stop iterating on n == 2
})
if err != nil {
	log.Printf("ScanRows: %v", err)
	return
}

fmt.Println("Done.")
Output:

1
2
Done.
Example (ReturnError)
ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
if err != nil {
	log.Printf("Open: %v", err)
	return
}
defer db.Close()

rows, err := db.QueryContext(ctx, ``+
	`SELECT 1`+
	` UNION ALL`+
	` SELECT 2`+
	` UNION ALL`+
	` SELECT 3`)
if err != nil {
	log.Printf("Query: %v", err)
	return
}

err = sqlfunc.ForEach(rows, func(n int) error {
	fmt.Println(n)
	if n == 2 {
		return io.EOF
	}
	return nil
})
if err != nil && !errors.Is(err, io.EOF) {
	log.Printf("ScanRows: %v", err)
	return
}

fmt.Println("Done.")
Output:

1
2
Done.

func Query

func Query(ctx context.Context, db PrepareConn, query string, fnPtr interface{}) (close func() error, err error)

Query prepares an SQL statement and creates a function wrapping sql.Stmt.QueryContext.

fnPtr is a pointer to a func variable. The function signature tells how it will be called.

The first argument is a context.Context. If a *sql.Tx is given as the second argument, the statement will be localized to the transaction (using sql.Tx.StmtContext). The following arguments will be given as arguments to sql.Stmt.QueryContext.

The function will return an *sql.Rows and an error.

The returned func 'close' must be called once the statement is not needed anymore.

Example
check := func(msg string, err error) {
	if err != nil {
		panic(fmt.Errorf("%s: %v", msg, err))
	}
}

ctx := context.Background()
db, err := sql.Open(sqliteDriver, "file:testdata/poi.db?mode=ro&immutable=1")
check("Open", err)
defer db.Close()

var queryNames func(ctx context.Context) (*sql.Rows, error)
closeQueryNames, err := sqlfunc.Query(
	ctx, db,
	`SELECT name FROM poi ORDER BY name`,
	&queryNames,
)
check("Prepare queryNames", err)
defer closeQueryNames()

rows, err := queryNames(ctx)
check("queryNames", err)
err = sqlfunc.ForEach(rows, func(name string) {
	fmt.Println("-", name)
})
check("read rows", err)
Output:

- Château de Versailles
- Villeperdue
Example (WithArgs)
check := func(msg string, err error) {
	if err != nil {
		panic(fmt.Errorf("%s: %v", msg, err))
	}
}

ctx := context.Background()
db, err := sql.Open(sqliteDriver, "file:testdata/poi.db?mode=ro&immutable=1")
check("Open", err)
defer db.Close()

var queryByName func(ctx context.Context, name string) (*sql.Rows, error)
closeQueryByName, err := sqlfunc.Query(
	ctx, db,
	`SELECT lat, lon FROM poi WHERE name = ?`,
	&queryByName,
)
check("Prepare queryByName", err)
defer closeQueryByName()

rows, err := queryByName(ctx, "Château de Versailles")
check("queryByName", err)
err = sqlfunc.ForEach(rows, func(lat, lon float64) {
	fmt.Printf("(%.4f %.4f)\n", lat, lon)
})
check("read rows", err)
Output:

(48.8016 2.1204)

func QueryRow

func QueryRow(ctx context.Context, db PrepareConn, query string, fnPtr interface{}) (close func() error, err error)

QueryRow prepares an SQL statement and creates a function wrapping sql.Stmt.QueryRowContext and sql.Row.Scan.

fnPtr is a pointer to a func variable. The function signature tells how it will be called.

The first argument is a context.Context. If a *sql.Tx is given as the second argument, the statement will be localized to the transaction (using sql.Tx.StmtContext). The following arguments will be given as arguments to sql.Stmt.QueryRowContext.

The function will return values scanned from the sql.Row and an error.

The returned func 'close' must be called once the statement is not needed anymore.

Example (WithArgs)
check := func(msg string, err error) {
	if err != nil {
		panic(fmt.Errorf("%s: %v", msg, err))
	}
}

ctx := context.Background()
db, err := sql.Open(sqliteDriver, "file:testdata/poi.db?mode=ro&immutable=1")
check("Open", err)
defer db.Close()

var queryByName func(ctx context.Context, name string) (lat, lon float64, err error)
closeQueryByName, err := sqlfunc.QueryRow(
	ctx, db,
	`SELECT lat, lon FROM poi WHERE name = ?`,
	&queryByName,
)
check("Prepare queryByName", err)
defer closeQueryByName()

lat, lon, err := queryByName(ctx, "Château de Versailles")
check("queryByName", err)
fmt.Printf("(%.4f %.4f)\n", lat, lon)
Output:

(48.8016 2.1204)

func Scan

func Scan(fnPtr interface{})

Scan allows you to define a function that scans one row from an *sql.Rows.

The signature of the function defines how the column values are retrieved into variables. Two styles are available:

  • as pointer variables (like sql.Rows.Scan): func (rows *sql.Rows, pval1 *int, pval2 *string) error
  • as returned values (implies copies): func (rows *sql.Rows) (val1 int, val2 string, err error)
Example
ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
if err != nil {
	log.Printf("Open: %v", err)
	return
}
defer db.Close()

var scan1 func(*sql.Rows, *int) error
rows, err := db.QueryContext(ctx, ``+
	`SELECT 1`+
	` UNION ALL`+
	` SELECT 2`)
if err != nil {
	log.Printf("Query1: %v", err)
	return
}
defer rows.Close()

sqlfunc.Scan(&scan1)

var values1 []int
for rows.Next() {
	var n int
	if err = scan1(rows, &n); err != nil {
		log.Printf("Scan1: %v", err)
		return
	}
	values1 = append(values1, n)
}
if err = rows.Err(); err != nil {
	log.Printf("Next1: %v", err)
}
fmt.Println(values1)

var scan2 func(*sql.Rows) (string, error)
rows, err = db.QueryContext(ctx, ``+
	`SELECT 'a'`+
	` UNION ALL`+
	` SELECT 'b'`)
if err != nil {
	log.Printf("Query2: %v", err)
	return
}
defer rows.Close()

sqlfunc.Scan(&scan2)

var values2 []string
for rows.Next() {
	s, err := scan2(rows)
	if err != nil {
		log.Printf("Scan2: %v", err)
		return
	}
	values2 = append(values2, s)
}
if err = rows.Err(); err != nil {
	log.Printf("Next2: %v", err)
}
fmt.Println(values2)
Output:

[1 2]
[a b]
Example (Any)
ctx := context.Background()
db, err := sql.Open(sqliteDriver, ":memory:")
if err != nil {
	log.Printf("Open: %v", err)
	return
}
defer db.Close()

var scan1 func(*sql.Rows, *interface{}) error
rows, err := db.QueryContext(ctx, ``+
	`SELECT 1`+
	` UNION ALL`+
	` SELECT NULL`+
	` UNION ALL`+
	` SELECT 'a'`)
if err != nil {
	log.Printf("Query1: %v", err)
	return
}
defer rows.Close()

sqlfunc.Scan(&scan1)

for rows.Next() {
	var v interface{}
	if err = scan1(rows, &v); err != nil {
		log.Printf("Scan1: %v", err)
		return
	}
	fmt.Printf("%T %#[1]v\n", v)
}
if err = rows.Err(); err != nil {
	log.Printf("Next1: %v", err)
}
Output:

int64 1
<nil> <nil>
string "a"

Types

type PrepareConn

type PrepareConn interface {
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}

PrepareConn is the PrepareContext method shared by *database/sql.DB, *database/sql.Conn and *database/sql.Tx; any of them can be passed where a PrepareConn is expected.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL