package pattern

import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"reflect"

	"golang.org/x/exp/typeparams"
)

// astTypes maps the names of AST node struct types to their reflect.Type.
// NodeToAST looks up the name of a pattern node's type in this table to find
// the AST struct it has to instantiate.
var astTypes = map[string]reflect.Type{
	"Ellipsis":       reflect.TypeOf(ast.Ellipsis{}),
	"RangeStmt":      reflect.TypeOf(ast.RangeStmt{}),
	"AssignStmt":     reflect.TypeOf(ast.AssignStmt{}),
	"IndexExpr":      reflect.TypeOf(ast.IndexExpr{}),
	"IndexListExpr":  reflect.TypeOf(typeparams.IndexListExpr{}),
	"Ident":          reflect.TypeOf(ast.Ident{}),
	"ValueSpec":      reflect.TypeOf(ast.ValueSpec{}),
	"GenDecl":        reflect.TypeOf(ast.GenDecl{}),
	"BinaryExpr":     reflect.TypeOf(ast.BinaryExpr{}),
	"ForStmt":        reflect.TypeOf(ast.ForStmt{}),
	"ArrayType":      reflect.TypeOf(ast.ArrayType{}),
	"DeferStmt":      reflect.TypeOf(ast.DeferStmt{}),
	"MapType":        reflect.TypeOf(ast.MapType{}),
	"ReturnStmt":     reflect.TypeOf(ast.ReturnStmt{}),
	"SliceExpr":      reflect.TypeOf(ast.SliceExpr{}),
	"StarExpr":       reflect.TypeOf(ast.StarExpr{}),
	"UnaryExpr":      reflect.TypeOf(ast.UnaryExpr{}),
	"SendStmt":       reflect.TypeOf(ast.SendStmt{}),
	"SelectStmt":     reflect.TypeOf(ast.SelectStmt{}),
	"ImportSpec":     reflect.TypeOf(ast.ImportSpec{}),
	"IfStmt":         reflect.TypeOf(ast.IfStmt{}),
	"GoStmt":         reflect.TypeOf(ast.GoStmt{}),
	"Field":          reflect.TypeOf(ast.Field{}),
	"SelectorExpr":   reflect.TypeOf(ast.SelectorExpr{}),
	"StructType":     reflect.TypeOf(ast.StructType{}),
	"KeyValueExpr":   reflect.TypeOf(ast.KeyValueExpr{}),
	"FuncType":       reflect.TypeOf(ast.FuncType{}),
	"FuncLit":        reflect.TypeOf(ast.FuncLit{}),
	"FuncDecl":       reflect.TypeOf(ast.FuncDecl{}),
	"ChanType":       reflect.TypeOf(ast.ChanType{}),
	"CallExpr":       reflect.TypeOf(ast.CallExpr{}),
	"CaseClause":     reflect.TypeOf(ast.CaseClause{}),
	"CommClause":     reflect.TypeOf(ast.CommClause{}),
	"CompositeLit":   reflect.TypeOf(ast.CompositeLit{}),
	"EmptyStmt":      reflect.TypeOf(ast.EmptyStmt{}),
	"SwitchStmt":     reflect.TypeOf(ast.SwitchStmt{}),
	"TypeSwitchStmt": reflect.TypeOf(ast.TypeSwitchStmt{}),
	"TypeAssertExpr": reflect.TypeOf(ast.TypeAssertExpr{}),
	"TypeSpec":       reflect.TypeOf(ast.TypeSpec{}),
	"InterfaceType":  reflect.TypeOf(ast.InterfaceType{}),
	"BranchStmt":     reflect.TypeOf(ast.BranchStmt{}),
	"IncDecStmt":     reflect.TypeOf(ast.IncDecStmt{}),
	"BasicLit":       reflect.TypeOf(ast.BasicLit{}),
}

// ASTToNode converts a value taken from a Go AST into a pattern Node.
//
// The conversion rules are:
//   - nil, and nil pointers to AST nodes, become Nil{}.
//   - strings become String and token.Token values become Token.
//   - *ast.ExprStmt and *ast.ParenExpr are replaced by the conversion of
//     their X field.
//   - *ast.BlockStmt and *ast.FieldList are replaced by the conversion of
//     their List field.
//   - any other ast.Node is converted field by field, using the fields listed
//     in structNodes for its type name, and assembled with populateNode.
//   - an empty slice becomes List{}, a slice with exactly one element becomes
//     the conversion of that element alone (not a one-element List), and a
//     longer slice becomes a chain of List values terminated by List{}.
//
// ASTToNode panics when given an *ast.File, an ast.Node whose type name is
// not present in structNodes, or a value of any other unhandled type.
func ASTToNode(node interface{}) Node {
	switch node := node.(type) {
	case *ast.File:
		panic("cannot convert *ast.File to Node")
	case nil:
		return Nil{}
	case string:
		return String(node)
	case token.Token:
		return Token(node)
	case *ast.ExprStmt:
		// Guard against typed nil pointers, like the other pointer cases
		// below, instead of dereferencing them.
		if node == nil {
			return Nil{}
		}
		return ASTToNode(node.X)
	case *ast.BlockStmt:
		if node == nil {
			return Nil{}
		}
		return ASTToNode(node.List)
	case *ast.FieldList:
		if node == nil {
			return Nil{}
		}
		return ASTToNode(node.List)
	case *ast.BasicLit:
		if node == nil {
			return Nil{}
		}
		// Non-nil literals are handled by the generic path below.
	case *ast.ParenExpr:
		if node == nil {
			return Nil{}
		}
		return ASTToNode(node.X)
	}

	if node, ok := node.(ast.Node); ok {
		name := reflect.TypeOf(node).Elem().Name()
		T, ok := structNodes[name]
		if !ok {
			panic(fmt.Sprintf("internal error: unhandled type %T", node))
		}

		if reflect.ValueOf(node).IsNil() {
			return Nil{}
		}
		v := reflect.ValueOf(node).Elem()
		objs := make([]Node, T.NumField())
		for i := 0; i < T.NumField(); i++ {
			// Fields are matched by name between the type registered in
			// structNodes and the AST struct.
			f := v.FieldByName(T.Field(i).Name)
			objs[i] = ASTToNode(f.Interface())
		}

		n, err := populateNode(name, objs, false)
		if err != nil {
			panic(fmt.Sprintf("internal error: %s", err))
		}
		return n
	}

	s := reflect.ValueOf(node)
	if s.Kind() == reflect.Slice {
		if s.Len() == 0 {
			return List{}
		}
		if s.Len() == 1 {
			return ASTToNode(s.Index(0).Interface())
		}

		// Build the list back to front so that each element becomes the
		// head of a List whose tail is everything after it.
		tail := List{}
		for i := s.Len() - 1; i >= 0; i-- {
			head := ASTToNode(s.Index(i).Interface())
			l := List{
				Head: head,
				Tail: tail,
			}
			tail = l
		}
		return tail
	}

	panic(fmt.Sprintf("internal error: unhandled type %T", node))
}

// NodeToAST converts a pattern Node back into a Go AST value, resolving
// Binding nodes through state.
//
// The result depends on the kind of node:
//   - Binding yields the value stored in state under the binding's name; a
//     types.Object is turned into an *ast.Ident carrying the object's name.
//   - List yields a []ast.Node.
//   - Token yields a token.Token, String yields a string, Nil yields nil.
//   - any other node whose type name is present in astTypes yields a pointer
//     to a newly allocated AST struct of that type, with every field that
//     exists under the same name in the pattern node converted recursively.
//
// NodeToAST panics if a binding has no value in state, if the node is one of
// Builtin, Any, Object, Function, Not or Or, if the node's type is not in
// astTypes, or if a converted value cannot be stored in the corresponding
// AST field.
func NodeToAST(node Node, state State) interface{} {
	switch node := node.(type) {
	case Binding:
		v, ok := state[node.Name]
		if !ok {
			// really we want to return an error here
			panic(fmt.Sprintf("internal error: binding %v has no value in state", node.Name))
		}
		switch v := v.(type) {
		case types.Object:
			return &ast.Ident{Name: v.Name()}
		default:
			return v
		}
	case Builtin, Any, Object, Function, Not, Or:
		panic(fmt.Sprintf("internal error: cannot convert %T to AST", node))
	case List:
		if (node == List{}) {
			return []ast.Node{}
		}
		x := []ast.Node{NodeToAST(node.Head, state).(ast.Node)}
		x = append(x, NodeToAST(node.Tail, state).([]ast.Node)...)
		return x
	case Token:
		return token.Token(node)
	case String:
		return string(node)
	case Nil:
		return nil
	}

	name := reflect.TypeOf(node).Name()
	T, ok := astTypes[name]
	if !ok {
		panic(fmt.Sprintf("internal error: unhandled type %T", node))
	}
	v := reflect.ValueOf(node)
	out := reflect.New(T)
	for i := 0; i < T.NumField(); i++ {
		fNode := v.FieldByName(T.Field(i).Name)
		if !fNode.IsValid() {
			// The pattern node has no field of this name; leave the AST
			// field at its zero value.
			continue
		}
		fAST := out.Elem().FieldByName(T.Field(i).Name)
		switch fAST.Type().Kind() {
		case reflect.Slice:
			c := reflect.ValueOf(NodeToAST(fNode.Interface().(Node), state))
			if !c.IsValid() {
				// The pattern converted to nil; leave the slice nil, the
				// same way nil results are skipped for non-slice fields.
				continue
			}
			if c.Kind() != reflect.Slice {
				// it's a single node in the pattern, we have to wrap
				// it in a slice
				slice := reflect.MakeSlice(fAST.Type(), 1, 1)
				slice.Index(0).Set(c)
				c = slice
			}
			switch fAST.Interface().(type) {
			case []ast.Node:
				switch cc := c.Interface().(type) {
				case []ast.Node:
					fAST.Set(c)
				case []ast.Expr:
					var slice []ast.Node
					for _, el := range cc {
						slice = append(slice, el)
					}
					fAST.Set(reflect.ValueOf(slice))
				default:
					panic(fmt.Sprintf("internal error: cannot store %s in field %s.%s of type %s", c.Type(), name, T.Field(i).Name, fAST.Type()))
				}
			case []ast.Expr:
				switch cc := c.Interface().(type) {
				case []ast.Node:
					var slice []ast.Expr
					for _, el := range cc {
						slice = append(slice, el.(ast.Expr))
					}
					fAST.Set(reflect.ValueOf(slice))
				case []ast.Expr:
					fAST.Set(c)
				default:
					panic(fmt.Sprintf("internal error: cannot store %s in field %s.%s of type %s", c.Type(), name, T.Field(i).Name, fAST.Type()))
				}
			default:
				panic(fmt.Sprintf("internal error: unsupported slice field %s.%s of type %s", name, T.Field(i).Name, fAST.Type()))
			}
		case reflect.Int:
			c := reflect.ValueOf(NodeToAST(fNode.Interface().(Node), state))
			switch c.Kind() {
			case reflect.String:
				// Tokens may be spelled as strings in patterns; translate
				// them through tokensByString.
				s := c.Interface().(string)
				tok, ok := tokensByString[s]
				if !ok {
					// really we want to return an error here
					panic(fmt.Sprintf("internal error: unknown token %q", s))
				}
				fAST.SetInt(int64(tok))
			case reflect.Int:
				fAST.Set(c)
			default:
				panic(fmt.Sprintf("internal error: unexpected kind %s", c.Kind()))
			}
		default:
			r := NodeToAST(fNode.Interface().(Node), state)
			if r != nil {
				fAST.Set(reflect.ValueOf(r))
			}
		}
	}

	return out.Interface().(ast.Node)
}