douqiao2471 2019-05-12 13:46
浏览 17
已采纳

具有可变输入/输出类型的通用函数

Just playing with the AWS SDK for Go. When listing resources of different types I tend to end up with a lot of very similar functions, like the two in the example below. Is there a way to rewrite them as one generic function that returns a specific type depending on what is passed in as a parameter?

Something like:

func generic(session, funcToCall, t, input) (interface{}, error) {}

currently I have to do this (functionality is the same just types change):

// getVolumes collects every volume reported by DescribeVolumes for the given
// session, re-issuing the request with the returned NextToken until the
// service stops handing one back. On any request error it returns nil and
// that error.
func getVolumes(s *session.Session) ([]*ec2.Volume, error) {
    svc := ec2.New(s)
    req := &ec2.DescribeVolumesInput{}
    volumes := []*ec2.Volume{}

    for {
        page, err := svc.DescribeVolumes(req)
        if err != nil {
            return nil, err
        }
        volumes = append(volumes, page.Volumes...)

        // A missing token marks the last page.
        if page.NextToken == nil {
            return volumes, nil
        }
        req.NextToken = page.NextToken
    }
}

// getVpcs collects every VPC reported by DescribeVpcs for the given session,
// re-issuing the request with the returned NextToken until the service stops
// handing one back. On any request error it returns nil and that error.
func getVpcs(s *session.Session) ([]*ec2.Vpc, error) {
    svc := ec2.New(s)
    req := &ec2.DescribeVpcsInput{}
    vpcs := []*ec2.Vpc{}

    for {
        page, err := svc.DescribeVpcs(req)
        if err != nil {
            return nil, err
        }
        vpcs = append(vpcs, page.Vpcs...)

        // A missing token marks the last page.
        if page.NextToken == nil {
            return vpcs, nil
        }
        req.NextToken = page.NextToken
    }
}
  • 写回答

2条回答 默认 最新

  • down2323 2019-05-12 19:58
    关注

    Because you only deal with functions, it is possible to use the reflect package to generate functions at runtime.

    Starting from the object type (Volume, Vpc), it is possible to derive all the subsequent information needed to provide a fully generic implementation that is really DRY, at the expense of being more complex and slower.

    It is untested, and you are welcome to help test and fix it, but something like this should put you on the right track:

    https://play.golang.org/p/mGjtYVG2OZS

    The registry idea comes from this answer: https://stackoverflow.com/a/23031445/4466350

    For reference, the Go documentation of the reflect package is at https://golang.org/pkg/reflect/

    package main
    
    import (
        "errors"
        "fmt"
        "reflect"
    )
    
    func main() {
        fmt.Printf("%T
    ", getter(Volume{}))
        fmt.Printf("%T
    ", getter(Vpc{}))
    }
    
    // Empty placeholder types used by this example (apparently mirroring the
    // AWS SDK names from the question — confirm before relying on that).
    type (
        DescribeVolumesInput struct{}
        // NOTE(review): unlike DescribeVolumesInput this name has no "Input"
        // suffix; presumably DescribeVpcsInput was intended — confirm.
        DescribeVpcs struct{}

        Volume struct{}
        Vpc    struct{}

        Session struct{}

        Client struct{}
    )

    // New returns a zero-value Client; the session argument is not used.
    func New(s *Session) Client {
        var c Client
        return c
    }
    
    // typeRegistry maps a type's %T-formatted name (for example
    // "main.DescribeVolumesInput") to its reflect.Type so that values of that
    // type can later be created from a name built at runtime.
    var typeRegistry = map[string]reflect.Type{}

    // init fills typeRegistry with the request input types known to this
    // example.
    func init() {
        for _, sample := range []interface{}{DescribeVolumesInput{}, DescribeVpcs{}} {
            typ := reflect.TypeOf(sample)
            typeRegistry[fmt.Sprintf("%T", sample)] = typ
        }
    }
    
    // Reflection handles for the built-in error interface.
    var (
        // errV only provides a variable of static type error whose address
        // can be reflected on.
        errV = errors.New("")
        // errType is the reflect.Type of the error interface itself.
        errType = reflect.TypeOf(&errV).Elem()
        // zeroErr is a reflect.Value holding a nil error.
        zeroErr = reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())
        // nilErr is a one-element result list containing only the nil error.
        nilErr = []reflect.Value{zeroErr}
    )
    
    // getter builds, at runtime, a paginating "describe everything" function
    // for the resource type of the sample value of (for example Volume{}).
    //
    // For a sample of type T the returned value is a func(*Session) ([]*T, error)
    // wrapped in an interface{}. When called, that function:
    //   - creates a client with New,
    //   - instantiates the request struct registered in typeRegistry under
    //     "<pkg>.Describe<T>sInput",
    //   - repeatedly calls the client method "Describe<T>s" with a pointer to
    //     that struct, appending the "<T>s" field of every result, and
    //   - copies the result's NextToken into the request until it is nil.
    //
    // Any lookup, signature or request failure is reported through the error
    // result together with a nil slice; nothing is reported by panicking once
    // the function has been built. getter itself still panics if of is nil,
    // because no type can be derived from it.
    func getter(of interface{}) interface{} {
        elemType := reflect.TypeOf(of)
        outType := reflect.SliceOf(reflect.PtrTo(elemType))
        fnType := reflect.FuncOf([]reflect.Type{reflect.TypeOf(new(Session))}, []reflect.Type{outType, errType}, false)

        // %T (and Type.String) yield the package-qualified name, e.g.
        // "main.Volume", which cannot be spliced into a method or field name.
        // Use the bare name for those, and keep the package prefix only for the
        // registry key, because init registers types under their %T name.
        bare := elemType.Name()
        qualified := elemType.String()
        plural := bare + "s"
        mName := "Describe" + plural
        inputKey := qualified[:len(qualified)-len(bare)] + mName + "Input"

        // fail produces the ([]*T)(nil), err result pair. The error is wrapped
        // through a pointer so the reflect.Value has static type error rather
        // than the concrete error type.
        fail := func(err error) []reflect.Value {
            return []reflect.Value{reflect.Zero(outType), reflect.ValueOf(&err).Elem()}
        }

        fnBody := func(input []reflect.Value) []reflect.Value {
            client := reflect.ValueOf(New).Call(input)[0]

            inType, ok := typeRegistry[inputKey]
            if !ok || inType.Kind() != reflect.Struct {
                return fail(fmt.Errorf("no input struct registered as %q", inputKey))
            }
            // Pointer to a fresh request struct; reused across pages.
            descInput := reflect.New(inType)

            meth := client.MethodByName(mName)
            if !meth.IsValid() {
                return fail(fmt.Errorf("no such method %q", mName))
            }
            // Calling through reflection panics on a mismatched signature, so
            // verify the expected func(*Input) (Output, error) shape up front.
            if mt := meth.Type(); mt.NumIn() != 1 || mt.NumOut() != 2 ||
                !descInput.Type().AssignableTo(mt.In(0)) || mt.Out(1) != errType {
                return fail(fmt.Errorf("method %q has unsupported signature %s", mName, mt))
            }

            t := reflect.MakeSlice(outType, 0, 0)
            for {
                out := meth.Call([]reflect.Value{descInput})
                if errOut := out[1]; !errOut.IsNil() {
                    return []reflect.Value{reflect.Zero(outType), errOut}
                }

                // The result is the first return value; it may be a struct or
                // a pointer to one.
                result := reflect.Indirect(out[0])
                if result.Kind() != reflect.Struct {
                    return fail(fmt.Errorf("method %q returned no result struct", mName))
                }

                items := result.FieldByName(plural)
                if !items.IsValid() || items.Type() != outType {
                    return fail(fmt.Errorf("field not found %q", plural))
                }
                t = reflect.AppendSlice(t, items)

                // A field that exists but is nil means there are no more
                // pages; testing only IsValid would loop forever.
                next := result.FieldByName("NextToken")
                if !next.IsValid() || next.Kind() != reflect.Ptr || next.IsNil() {
                    break
                }
                dst := descInput.Elem().FieldByName("NextToken")
                if !dst.CanSet() || !next.Type().AssignableTo(dst.Type()) {
                    return fail(fmt.Errorf("input %q cannot carry NextToken", inputKey))
                }
                dst.Set(next)
            }
            return []reflect.Value{t, zeroErr}
        }
        return reflect.MakeFunc(fnType, fnBody).Interface()
    }
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

悬赏问题

  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置
  • ¥15 有没有研究水声通信方面的帮我改俩matlab代码
  • ¥15 对于相关问题的求解与代码
  • ¥15 ubuntu子系统密码忘记
  • ¥15 信号傅里叶变换在matlab上遇到的小问题请求帮助
  • ¥15 保护模式-系统加载-段寄存器
  • ¥15 电脑桌面设定一个区域禁止鼠标操作
  • ¥15 求NPF226060磁芯的详细资料