pglisten/pglisten.go

197 lines
3.5 KiB
Go

package pglisten
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/lib/pq"
"strings"
"time"
)
// Package-level state shared by every function in pglisten. On fills it in
// during setup, and it is then read by the goroutines that On starts.
var (
	// db is the pool used for queue-table queries; connect assigns it.
	db *sql.DB
	// listener delivers LISTEN/NOTIFY notifications; listen creates it.
	listener *pq.Listener

	// channel_name is the notification channel passed to listener.Listen.
	channel_name string
	// connection_info is the connection string used for both db and listener.
	connection_info string
	// queue_table is the table holding pending messages.
	queue_table string
	// spawn_event names a database function whose existence check verifies.
	spawn_event string
)
type (
	// Message is one queued notification as decoded from its JSON text.
	// Payload has no json tag; encoding/json still matches a "payload" key
	// to it case-insensitively when decoding.
	Message struct {
		ID      string `json:"id"`
		Failed  bool   `json:"failed"`
		Trigger string `json:"trigger"`
		Payload map[string]any
	}

	// Handler consumes a Message and reports whether it was handled
	// successfully (true) or should be recorded as failed (false).
	Handler func(Message) bool
)
func parse(msg string) (m Message) {
if err := json.Unmarshal([]byte(msg), &m); err != nil {
log(err.Error())
}
return m
}
// process decodes one queued notification and hands it to fn.
//
// Some messages are skipped without calling fn:
//   - Messages without an "id", including payloads that failed to decode.
//     Both drop and failed locate the queue row by id, so neither could act
//     on such a message, and fn would only ever see an empty Message.
//   - Messages already marked as failed.
//
// When fn returns true the queue row is deleted. Otherwise the row is
// flagged as failed, so later calls skip it.
func process(msg string, fn Handler) {
	m := parse(msg)
	switch {
	case m.ID == "":
		log("pglisten: skipping notification without an id:", msg)
		return
	case m.Failed:
		log("pglisten: Message", m.ID, "is marked as failed. You fix it.")
		return
	}
	// drop and failed log their own errors; nothing more to do here.
	if fn(m) {
		drop(m)
	} else {
		failed(m)
	}
}
// drop deletes the queue row whose JSON message has the given m.ID. It is
// called after the handler reports success.
//
// The table name is package configuration and is interpolated into the SQL.
// The ID comes from notification data, so it is passed as a bind parameter
// rather than spliced into the SQL text.
func drop(m Message) (err error) {
	query := fmt.Sprintf(`delete from %s where message->>'id' = $1`, queue_table)
	if _, err = db.Exec(query, m.ID); err != nil {
		log(err.Error())
		return err
	}
	log(fmt.Sprintf(`deleted notifications: %s`, m.ID))
	return nil
}
// failed flags the queue row identified by m.ID as failed. It merges
// {"failed": true} into the row's JSON message, which makes process skip it
// from then on.
//
// The ID comes from notification data, so it is bound as a query parameter
// instead of being formatted into the SQL string.
func failed(m Message) (err error) {
	query := fmt.Sprintf(`
		update %s set
			message = message || jsonb_build_object('failed', true)
		where message->>'id' = $1`, queue_table)
	if _, err = db.Exec(query, m.ID); err != nil {
		log(err.Error())
		return err
	}
	log(fmt.Sprintf(`updated notification as failed: %s`, m.ID))
	return nil
}
// queued reads every message currently stored in the queue table and starts
// one process goroutine per message with handler f.
//
// It returns the first query, scan, or iteration error. Goroutines that were
// already started keep running regardless.
func queued(f Handler) (err error) {
	log("Checking for queued notifications")
	events, err := db.Query(fmt.Sprintf("select message from %s", queue_table))
	if err != nil {
		return err
	}
	// Release the underlying connection on every return path, including a
	// Scan failure part-way through the result set.
	defer events.Close()
	for events.Next() {
		var msg string
		if err := events.Scan(&msg); err != nil {
			return err
		}
		go process(msg, f)
	}
	// Next also returns false when iteration fails; report that failure
	// instead of treating it as a complete read.
	return events.Err()
}
// log prints a single line to standard output. The line is an RFC 3339
// timestamp, a space, and then the parts joined with single spaces.
func log(parts ...string) {
	stamp := time.Now().Format(time.RFC3339)
	line := strings.Join(parts, " ")
	fmt.Printf("%s %s\n", stamp, line)
}
// connect opens a *sql.DB for connection_info using the "postgres" driver
// and stores it in the package-level db.
//
// db is replaced only when sql.Open succeeds. On error the existing handle,
// if any, stays in place rather than being overwritten with nil. Note that
// sql.Open does not itself dial the server.
func connect() error {
	conn, err := sql.Open("postgres", connection_info)
	if err != nil {
		return err
	}
	db = conn
	return nil
}
func reconnect() (err error) {
if err := connect(); err != nil {
log("Failed to reconnect!", err.Error())
}
return err
}
// listen creates the package-level pq.Listener for connection_info and
// subscribes it to channel_name.
//
// The listener is created with a 1s minimum and 1m maximum reconnect
// interval. Listener events that carry an error are logged. An error from
// Listen is logged and returned; it used to be swallowed, which made On's
// failure branch unreachable.
func listen() (err error) {
	log("Listening on channel:", channel_name)
	h := func(_ pq.ListenerEventType, err error) {
		if err != nil {
			log("pglisten error handler:", err.Error())
		}
	}
	listener = pq.NewListener(connection_info, time.Second, time.Minute, h)
	if err := listener.Listen(channel_name); err != nil {
		log(err.Error())
		return err
	}
	return nil
}
// check verifies that the configured queue table and spawn function exist.
//
// The table is looked up by name in information_schema.tables, and the
// function by proname in pg_proc. The first lookup that finds no row (or
// fails) is logged and returned. Names are passed as bind parameters, so a
// quote character in a configured name cannot break or alter the query.
func check() (err error) {
	var one int
	t := db.QueryRow(`select 1 from information_schema.tables where table_name = $1`, queue_table)
	if err := t.Scan(&one); err != nil {
		log("pglisten check for queue table:", queue_table)
		log(err.Error())
		return err
	}
	r := db.QueryRow(`select 1 from pg_proc where proname = $1`, spawn_event)
	if err := r.Scan(&one); err != nil {
		log("pglisten check for", spawn_event, "function:", err.Error())
		return err
	}
	return nil
}
// On configures pglisten and then blocks forever, dispatching notifications
// from channel chnl to fn.
//
// Setup steps, each of which panics on failure:
//   - open the database (conninfo);
//   - verify the queue table qtable and the function spawn exist;
//   - start listening.
//
// Work is then dispatched as follows:
//   - Messages already in the queue table are processed in the background.
//   - Each notification is processed in its own goroutine.
//   - After 10s with no notification, the listener is pinged. If the ping
//     fails, reconnect is called, which refreshes only the db handle.
func On(conninfo string, chnl string, qtable string, spawn string, fn Handler) {
	connection_info = conninfo
	if err := connect(); err != nil {
		panic(err)
	}
	spawn_event = spawn
	queue_table = qtable
	if err := check(); err != nil {
		panic(err)
	}
	channel_name = chnl
	if err := listen(); err != nil {
		panic(err)
	}

	// drainQueue processes everything currently stored in the queue table.
	// It reports a read failure instead of silently discarding it.
	// NOTE(review): a message whose handler is still running when the queue
	// is drained can be handed to fn a second time; fn should tolerate this.
	drainQueue := func() {
		if err := queued(fn); err != nil {
			log("pglisten: reading queued notifications:", err.Error())
		}
	}
	go drainQueue()

	for {
		select {
		case m := <-listener.Notify:
			go func() {
				if m == nil {
					// Per the lib/pq Listener docs, nil is sent after the
					// listener re-establishes its connection, meaning
					// notifications may have been missed. Re-read the queue
					// table so those messages are not left waiting for a
					// restart.
					log("nil message. Reconnecting?")
					drainQueue()
					return
				}
				process(m.Extra, fn)
			}()
		case <-time.After(10 * time.Second):
			go func() {
				if err := listener.Ping(); err != nil {
					log("Failed to ping the database!", err.Error())
					reconnect()
				}
			}()
		}
	}
}