douqiao2471 2019-05-12 13:46
浏览 17
已采纳

具有可变输入/输出类型的通用函数

Just playing with the AWS SDK for Go. When listing resources of different types, I tend to end up with a lot of very similar functions, like the two in the example below. Is there a way to rewrite them as one generic function that returns a specific type depending on what is passed in as a parameter?

Something like:

func generic(session, funcToCall, t, input) (interface{}, error) {}

Currently I have to do this (the functionality is the same; only the types change):

// getVolumes accumulates the volumes returned by repeated DescribeVolumes
// calls, feeding each response's NextToken into the next request until a
// response arrives without one. On any error it returns nil and the error;
// pages collected so far are discarded.
func getVolumes(s *session.Session) ([]*ec2.Volume, error) {
	svc := ec2.New(s)
	volumes := make([]*ec2.Volume, 0)
	req := ec2.DescribeVolumesInput{}

	for {
		page, err := svc.DescribeVolumes(&req)
		if err != nil {
			return nil, err
		}
		volumes = append(volumes, page.Volumes...)

		// A nil token means this was the last page.
		if page.NextToken == nil {
			return volumes, nil
		}
		req.NextToken = page.NextToken
	}
}

// getVpcs accumulates the VPCs returned by repeated DescribeVpcs calls,
// feeding each response's NextToken into the next request until a response
// arrives without one. On any error it returns nil and the error; pages
// collected so far are discarded.
func getVpcs(s *session.Session) ([]*ec2.Vpc, error) {
	svc := ec2.New(s)
	vpcs := make([]*ec2.Vpc, 0)
	req := ec2.DescribeVpcsInput{}

	for {
		page, err := svc.DescribeVpcs(&req)
		if err != nil {
			return nil, err
		}
		vpcs = append(vpcs, page.Vpcs...)

		// A nil token means this was the last page.
		if page.NextToken == nil {
			return vpcs, nil
		}
		req.NextToken = page.NextToken
	}
}
  • 写回答

2条回答 默认 最新

  • down2323 2019-05-12 19:58
    关注

    Because you only deal with functions it is possible to use the reflect package to generate functions at runtime.

    Using the object type (Volume, Vpc), it is possible to derive all the remaining information (the method name, the input type and the output field name) and so provide a fully generic implementation that is really DRY, at the expense of being more complex and slower.

    It is untested — you are welcome to help test and fix it — but something like this should put you on the right track:

    https://play.golang.org/p/mGjtYVG2OZS

    The registry idea comes from this answer: https://stackoverflow.com/a/23031445/4466350

    For reference, the Go documentation for the reflect package is at https://golang.org/pkg/reflect/

    package main
    
    import (
        "errors"
        "fmt"
        "reflect"
    )
    
    func main() {
        fmt.Printf("%T
    ", getter(Volume{}))
        fmt.Printf("%T
    ", getter(Vpc{}))
    }
    
    // Stand-ins for the aws-sdk-go ec2 types so the sketch compiles on its own.
    // Input types follow the SDK naming convention Describe<Type>sInput; getter
    // derives the input type name from that convention, so the Vpc input is
    // named DescribeVpcsInput (not DescribeVpcs).
    type DescribeVolumesInput struct{}
    type DescribeVpcsInput struct{}

    type Volume struct{}
    type Vpc struct{}

    // Session stands in for *session.Session.
    type Session struct{}

    // Client stands in for the ec2 client; this stub declares no methods.
    type Client struct{}

    // New stands in for ec2.New.
    func New(s *Session) Client { return Client{} }

    // typeRegistry maps a %T-formatted, package-qualified type name
    // (e.g. "main.DescribeVolumesInput") to its reflect.Type, so input values
    // can be instantiated by name at runtime.
    var typeRegistry = make(map[string]reflect.Type)

    func init() {
        some := []interface{}{DescribeVolumesInput{}, DescribeVpcsInput{}}
        for _, v := range some {
            typeRegistry[fmt.Sprintf("%T", v)] = reflect.TypeOf(v)
        }
    }
    
    // Reflection helpers for the error result of the functions getter builds.
    var (
        // errV exists only so the static type error can be reached by reflection.
        errV = errors.New("")

        // errType is the reflect.Type of the error interface.
        errType = reflect.TypeOf(&errV).Elem()

        // zeroErr is a nil value of type error, returned on success.
        zeroErr = reflect.Zero(errType)

        // nilErr is a one-element result slice holding zeroErr; it is not
        // referenced by the code shown here.
        nilErr = []reflect.Value{zeroErr}
    )
    
    // getter uses reflection to build a function of type
    // func(*Session) ([]*T, error), where T is the dynamic type of of.
    //
    // For a type named T, the generated function:
    //   - creates a client by calling New with its *Session argument;
    //   - instantiates the input type registered in typeRegistry under the
    //     package-qualified name "<pkg>.DescribeTsInput";
    //   - repeatedly calls the client method "DescribeTs" with a pointer to
    //     that input and appends the output field "Ts" to the result;
    //   - copies a non-nil output NextToken into the input, and stops once the
    //     output NextToken is nil, absent, or not a pointer.
    //
    // Any mismatch with this naming convention, or a non-nil error from the
    // Describe method, yields (nil, err) rather than a panic.
    //
    // The result must be type-asserted by the caller, e.g.
    //
    //     fn := getter(Volume{}).(func(*Session) ([]*Volume, error))
    func getter(of interface{}) interface{} {
        ofType := reflect.TypeOf(of)
        outType := reflect.SliceOf(reflect.PtrTo(ofType))
        fnType := reflect.FuncOf(
            []reflect.Type{reflect.TypeOf(new(Session))},
            []reflect.Type{outType, errType},
            false,
        )

        // Method and field names need the bare type name ("Volume"), whereas
        // %T (used for the typeRegistry keys) is package-qualified
        // ("main.Volume"). Derive every name once, outside the generated function.
        base := ofType.Name()
        qualified := ofType.String()
        inName := qualified[:len(qualified)-len(base)] + "Describe" + base + "sInput"
        mName := "Describe" + base + "s"
        fName := base + "s"

        // fail builds the (nil, err) result. err is wrapped as an error-typed
        // Value so it matches the declared result type exactly.
        fail := func(err error) []reflect.Value {
            return []reflect.Value{reflect.Zero(outType), reflect.ValueOf(&err).Elem()}
        }

        fnBody := func(args []reflect.Value) []reflect.Value {
            client := reflect.ValueOf(New).Call(args)[0]

            inType, ok := typeRegistry[inName]
            if !ok {
                return fail(fmt.Errorf("no registered input type %q", inName))
            }
            descInput := reflect.New(inType).Elem()

            meth := client.MethodByName(mName)
            if !meth.IsValid() {
                return fail(fmt.Errorf("no such method %q", mName))
            }
            // Check the signature up front so Call cannot panic on a mismatch.
            mt := meth.Type()
            if mt.NumIn() != 1 || !reflect.PtrTo(inType).AssignableTo(mt.In(0)) ||
                mt.NumOut() != 2 || mt.Out(1) != errType {
                return fail(fmt.Errorf("method %q has signature %s, want func(*%s) (*Output, error)", mName, mt, inType))
            }

            t := reflect.MakeSlice(outType, 0, 0)
            for {
                out := meth.Call([]reflect.Value{descInput.Addr()})
                if errOut := out[1]; !errOut.IsNil() {
                    return fail(errOut.Interface().(error))
                }

                // The output is returned by pointer; fields live on the struct.
                result := reflect.Indirect(out[0])
                if result.Kind() != reflect.Struct {
                    return fail(fmt.Errorf("method %q returned %s, want a non-nil pointer to struct", mName, out[0].Type()))
                }

                items := result.FieldByName(fName)
                if !items.IsValid() {
                    return fail(fmt.Errorf("field not found %q", fName))
                }
                if items.Type() != outType {
                    return fail(fmt.Errorf("field %q has type %s, want %s", fName, items.Type(), outType))
                }
                t = reflect.AppendSlice(t, items)

                // The field is always present on the output struct, so checking
                // only its existence would never end the loop; the last page is
                // signalled by a nil token.
                next := result.FieldByName("NextToken")
                if !next.IsValid() || next.Kind() != reflect.Ptr || next.IsNil() {
                    break
                }
                inToken := descInput.FieldByName("NextToken")
                if !inToken.IsValid() || !inToken.CanSet() || !next.Type().AssignableTo(inToken.Type()) {
                    return fail(fmt.Errorf("input type %s has no settable NextToken field of type %s", inType, next.Type()))
                }
                inToken.Set(next)
            }
            return []reflect.Value{t, zeroErr}
        }
        return reflect.MakeFunc(fnType, fnBody).Interface()
    }
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥15 phython如何实现以下功能?查找同一用户名的消费金额合并—
  • ¥15 孟德尔随机化怎样画共定位分析图
  • ¥18 模拟电路问题解答有偿速度
  • ¥15 CST仿真别人的模型结果仿真结果S参数完全不对
  • ¥15 误删注册表文件致win10无法开启
  • ¥15 请问在阿里云服务器中怎么利用数据库制作网站
  • ¥60 ESP32怎么烧录自启动程序
  • ¥50 html2canvas超出滚动条不显示
  • ¥15 java业务性能问题求解(sql,业务设计相关)
  • ¥15 52810 尾椎c三个a 写蓝牙地址