// goduckdbudf

package main

import (
    "context"
    "database/sql"
    "database/sql/driver"
    "errors"
    "fmt"
    "github.com/twpayne/go-geos"

    "github.com/marcboeker/go-duckdb/v2"
)

// Overload length with two user-defined scalar functions.
// lineStringLen takes a VARCHAR (parsed as WKT geometry) as its input parameter.
// listLen takes a LIST(ANY) as its input parameter.

// lineStringLen is the stateless receiver for the VARCHAR overload of the
// length UDF set.
type lineStringLen struct{}

// listLen is the stateless receiver for the LIST(ANY) overload of the length
// UDF set.
type listLen struct{}

// lineStringLenFn is the row executor for the VARCHAR overload of the length
// UDF set. It parses values[0] as WKT with geos and returns geom.Length() of
// the resulting geometry.
//
// Bad input is reported through the error result rather than by panicking, so
// the driver can surface it as a query error instead of the panic unwinding
// through the DuckDB callback.
func lineStringLenFn(values []driver.Value) (any, error) {
	if len(values) != 1 {
		return nil, fmt.Errorf("lineStringLen: expected 1 argument, got %d", len(values))
	}
	wkt, ok := values[0].(string)
	if !ok {
		return nil, fmt.Errorf("lineStringLen: expected VARCHAR argument, got %T", values[0])
	}
	geom, err := geos.NewGeomFromWKT(wkt)
	if err != nil {
		return nil, fmt.Errorf("lineStringLen: parsing WKT: %w", err)
	}
	return geom.Length(), nil
}

// Config declares the signature of the VARCHAR overload: a single VARCHAR
// argument and a DOUBLE result.
//
// DOUBLE rather than FLOAT: lineStringLenFn returns the float64 produced by
// geom.Length(), and a 32-bit FLOAT result column would narrow it, losing
// about half of its significant digits.
func (*lineStringLen) Config() duckdb.ScalarFuncConfig {
	inputTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR)
	check(err)
	resultTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_DOUBLE)
	check(err)

	return duckdb.ScalarFuncConfig{
		InputTypeInfos: []duckdb.TypeInfo{inputTypeInfo},
		ResultTypeInfo: resultTypeInfo,
	}
}

// Executor wires lineStringLenFn in as the per-row implementation of the
// VARCHAR overload.
func (*lineStringLen) Executor() duckdb.ScalarFuncExecutor {
	exec := duckdb.ScalarFuncExecutor{}
	exec.RowExecutor = lineStringLenFn
	return exec
}

// listLenFn is the row executor for the LIST(ANY) overload of the length UDF
// set. It returns the number of entries in the list (including any nil
// entries) as a float64.
//
// A wrong argument count or type is returned as an error instead of causing
// a panic inside the UDF callback.
func listLenFn(values []driver.Value) (any, error) {
	if len(values) != 1 {
		return nil, fmt.Errorf("listLen: expected 1 argument, got %d", len(values))
	}
	list, ok := values[0].([]any)
	if !ok {
		return nil, fmt.Errorf("listLen: expected LIST argument, got %T", values[0])
	}
	return float64(len(list)), nil
}

// Config declares the signature of the LIST overload: one LIST(ANY) argument
// and a FLOAT result.
func (*listLen) Config() duckdb.ScalarFuncConfig {
	elemInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_ANY)
	check(err)
	listInfo, err := duckdb.NewListInfo(elemInfo)
	check(err)
	floatInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_FLOAT)
	check(err)

	var cfg duckdb.ScalarFuncConfig
	cfg.InputTypeInfos = []duckdb.TypeInfo{listInfo}
	cfg.ResultTypeInfo = floatInfo
	return cfg
}

// Executor wires listLenFn in as the per-row implementation of the LIST
// overload.
func (*listLen) Executor() duckdb.ScalarFuncExecutor {
	var exec duckdb.ScalarFuncExecutor
	exec.RowExecutor = listLenFn
	return exec
}

// myLengthScalarUDFSet registers lineStringLen and listLen as an overload set
// named "length", runs one query against each overload, and prints the two
// results. Any error panics via check.
//
// NOTE(review): DuckDB also ships a built-in function called length; confirm
// that registering a same-named set behaves as intended on the DuckDB version
// in use.
func myLengthScalarUDFSet() {
	db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
	check(err)

	c, err := db.Conn(context.Background())
	check(err)

	var lineStringLenUDF *lineStringLen
	var listUDF *listLen
	err = duckdb.RegisterScalarUDFSet(c, "length", lineStringLenUDF, listUDF)
	check(err)

	var length float64

	// VARCHAR overload: length of a WKT LINESTRING.
	row := db.QueryRow(`SELECT length('LINESTRING Z (121.2573237499999 31.21475976 15.11999999999999, 121.2572921499999 31.21481678 15.07, 121.25728415 31.21483049 15.08, 121.2572641499999 31.2148651 15.06, 121.25725525 31.21488051 15.03999999999999, 121.2572359499999 31.21491392 15.02999999999999, 121.25722704 31.21492903 15.02999999999999, 121.25720794 31.21496204 15.02, 121.25719904 31.21497725 15.02999999999999, 121.25717964 31.21501036 15.03999999999999, 121.25717034 31.21502617 15.03999999999999, 121.25715094 31.21505897999999 15.05, 121.2571419399999 31.21507429 15.03999999999999, 121.25712283 31.215107 15.02, 121.25711393 31.21512191 15.02999999999999, 121.25709433 31.21515511999999 15.03999999999999, 121.25708563 31.21516973 15.03999999999999, 121.2570657299999 31.21520324 15.02999999999999, 121.2570566299999 31.21521853999999 15.02999999999999, 121.25703703 31.21525125999999 15, 121.2570275199999 31.21526696 14.99, 121.25700802 31.21529977999999 14.98, 121.25699932 31.21531398 14.97, 121.2569793199999 31.2153472 14.97, 121.25696972 31.21536289999999 14.98, 121.25695262 31.21539140999999 14.97, 121.25694242 31.21540802 14.93, 121.25692331 31.21543983 14.9, 121.25691541 31.21545263999999 14.89, 121.25689771 31.21548185 14.88, 121.2568890099999 31.21549595999999 14.86999999999999, 121.25687141 31.21552486999999 14.84, 121.2568626099999 31.21553918 14.83, 121.2568421099999 31.21557279 14.78999999999999, 121.2568330099999 31.21558758999999 14.78999999999999, 121.2568126 31.21562101 14.76, 121.2568123 31.21562150999999 14.76)') AS sum`)
	check(row.Scan(&length))
	fmt.Println(length)

	// LIST overload. The function name must match the one registered above
	// ("length"); no function named line_length is registered.
	row = db.QueryRow(`SELECT length([1, 2, NULL, 4, NULL]) AS sum`)
	check(row.Scan(&length))
	fmt.Println(length)

	check(c.Close())
	check(db.Close())
}

// wrapSum takes a VARCHAR prefix, a VARCHAR suffix, and a variadic number of integer values.
// It computes the sum of the integer values. Then, it emits a VARCHAR by concatenating prefix || sum || suffix.

// wrapSum is the stateless receiver type implementing the wrap_sum scalar UDF.
type (
	wrapSum struct{}
)

// wrapSumFn is the row executor for wrap_sum. values[0] is the VARCHAR
// prefix, values[1] the VARCHAR suffix, and values[2:] the variadic INTEGER
// arguments. It returns prefix + decimal(sum of the integers) + suffix.
//
// The integers arrive as int32 (DuckDB INTEGER). They are summed in int64 so
// that a large sum is reported correctly instead of silently wrapping around.
// Unexpected argument counts or types are returned as errors instead of
// panicking inside the UDF callback.
func wrapSumFn(values []driver.Value) (any, error) {
	if len(values) < 2 {
		return nil, fmt.Errorf("wrap_sum: expected at least 2 arguments, got %d", len(values))
	}
	prefix, ok := values[0].(string)
	if !ok {
		return nil, fmt.Errorf("wrap_sum: prefix: expected VARCHAR, got %T", values[0])
	}
	suffix, ok := values[1].(string)
	if !ok {
		return nil, fmt.Errorf("wrap_sum: suffix: expected VARCHAR, got %T", values[1])
	}

	var sum int64
	for i, v := range values[2:] {
		n, ok := v.(int32)
		if !ok {
			// i+3 gives the 1-based SQL argument position.
			return nil, fmt.Errorf("wrap_sum: argument %d: expected INTEGER, got %T", i+3, v)
		}
		sum += int64(n)
	}
	return fmt.Sprintf("%s%d%s", prefix, sum, suffix), nil
}

// Config declares wrap_sum's signature: two fixed VARCHAR parameters (prefix
// and suffix), followed by any number of INTEGER arguments, returning VARCHAR.
func (*wrapSum) Config() duckdb.ScalarFuncConfig {
	varchar, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR)
	check(err)
	integer, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER)
	check(err)

	fixedParams := []duckdb.TypeInfo{varchar, varchar}
	return duckdb.ScalarFuncConfig{
		InputTypeInfos:   fixedParams,
		VariadicTypeInfo: integer,
		ResultTypeInfo:   varchar,
	}
}

// Executor wires wrapSumFn in as wrap_sum's per-row implementation.
func (*wrapSum) Executor() duckdb.ScalarFuncExecutor {
	var exec duckdb.ScalarFuncExecutor
	exec.RowExecutor = wrapSumFn
	return exec
}

// wrapSumScalarUDF registers wrap_sum and runs two queries against it: one
// with four variadic integers and one with none. The first result is printed.
// It panics if either result differs from the expected string, or on any
// error (via check).
func wrapSumScalarUDF() {
	db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
	check(err)

	conn, err := db.Conn(context.Background())
	check(err)

	var udf *wrapSum
	check(duckdb.RegisterScalarUDF(conn, "wrap_sum", udf))

	cases := []struct {
		query string
		want  string
		echo  bool
	}{
		{query: `SELECT wrap_sum('hello', ' world', 1, 2, 3, 4) AS sum`, want: "hello10 world", echo: true},
		{query: `SELECT wrap_sum('hello', ' world') AS sum`, want: "hello0 world"},
	}
	for _, tc := range cases {
		var got string
		check(db.QueryRow(tc.query).Scan(&got))
		if tc.echo {
			fmt.Println(got)
		}
		if got != tc.want {
			panic(errors.New("incorrect result"))
		}
	}

	check(conn.Close())
	check(db.Close())
}

// main runs the two UDF examples in order: the length overload set first,
// then wrap_sum.
func main() {
	examples := []func(){
		myLengthScalarUDFSet,
		wrapSumScalarUDF,
	}
	for _, run := range examples {
		run()
	}
}

// check panics with its last argument if that argument is non-nil. It is
// used as check(err), or as check(f()) where f's final result is an error.
// Calling it with no arguments is a no-op. Without that guard, it would panic
// with an index-out-of-range error.
func check(args ...any) {
	if len(args) == 0 {
		return
	}
	if err := args[len(args)-1]; err != nil {
		panic(err)
	}
}

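// © Copyright belongs to the author. For reprints or content collaboration, please contact the author.
// Platform notice: the article's content (including any images or videos) was uploaded and published by the author and represents only the author's own views. Jianshu is an information-publishing platform and provides only information-storage services.
//
// Recommended reading: more great content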